From 7ba279a3ff76bd17f25035ced90ea882812dfac8 Mon Sep 17 00:00:00 2001
From: Hyukjin Kwon
Date: Tue, 12 Nov 2024 08:32:59 +0900
Subject: [PATCH 01/39] [SPARK-50118][CONNECT] Reset isolated state cache when tasks are running

### What changes were proposed in this pull request?

This PR proposes to reset the expiry timeout of the isolated session while tasks are running.

### Why are the changes needed?

To prevent removal of artifacts for long-running tasks.

### Does this PR introduce _any_ user-facing change?

Yes, it fixes a bug. When users run Python (or Scala) UDFs for longer than the cache timeout (30 minutes) and other sessions submit tasks in the meantime, the Guava cache evicts the entry and removes the artifact directory dedicated to the session.

### How was this patch tested?

Manually tested after taking the logic out; it is difficult to write a test for this.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #48665 from HyukjinKwon/SPARK-50118.

Authored-by: Hyukjin Kwon
Signed-off-by: Hyukjin Kwon
---
 .../org/apache/spark/executor/Executor.scala | 21 ++++++++++++++++---
 1 file changed, 18 insertions(+), 3 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index 72d97bd787007..c299f38526aeb 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -22,7 +22,7 @@ import java.lang.Thread.UncaughtExceptionHandler import java.lang.management.ManagementFactory import java.net.{URI, URL} import java.nio.ByteBuffer -import java.util.{Locale, Properties} +import java.util.{Locale, Properties, Timer, TimerTask} import java.util.concurrent._ import java.util.concurrent.atomic.AtomicBoolean import java.util.concurrent.locks.ReentrantLock
@@ -209,9 +209,10 @@ private[spark] class Executor( // The default isolation group val defaultSessionState: IsolatedSessionState = newSessionState(JobArtifactState("default", None)) + private val cacheExpiryTime = 30 * 60 * 1000 val isolatedSessionCache: Cache[String, IsolatedSessionState] = CacheBuilder.newBuilder() .maximumSize(100) - .expireAfterAccess(30, TimeUnit.MINUTES) + .expireAfterAccess(cacheExpiryTime, TimeUnit.MILLISECONDS) .removalListener(new RemovalListener[String, IsolatedSessionState]() { override def onRemoval( notification: RemovalNotification[String, IsolatedSessionState]): Unit = {
@@ -295,6 +296,8 @@ private[spark] class Executor( private val pollOnHeartbeat = if (METRICS_POLLING_INTERVAL_MS > 0) false else true + private val timer = new Timer("executor-state-timer", true) + // Poller for the memory metrics. Visible for testing.
private[executor] val metricsPoller = new ExecutorMetricsPoller( env.memoryManager, @@ -445,6 +448,9 @@ private[spark] class Executor( case NonFatal(e) => logWarning("Unable to stop heartbeater", e) } + if (timer != null) { + timer.cancel() + } ShuffleBlockPusher.stop() if (threadPool != null) { threadPool.shutdown() @@ -559,9 +565,17 @@ private[spark] class Executor( override def run(): Unit = { // Classloader isolation + var maybeTimerTask: Option[TimerTask] = None val isolatedSession = taskDescription.artifacts.state match { case Some(jobArtifactState) => - isolatedSessionCache.get(jobArtifactState.uuid, () => newSessionState(jobArtifactState)) + val state = isolatedSessionCache.get( + jobArtifactState.uuid, () => newSessionState(jobArtifactState)) + maybeTimerTask = Some(new TimerTask { + // Resets the expire time till the task ends. + def run(): Unit = isolatedSessionCache.getIfPresent(jobArtifactState.uuid) + }) + maybeTimerTask.foreach(timer.schedule(_, cacheExpiryTime / 10, cacheExpiryTime / 10)) + state case _ => defaultSessionState } @@ -862,6 +876,7 @@ private[spark] class Executor( uncaughtExceptionHandler.uncaughtException(Thread.currentThread(), t) } } finally { + maybeTimerTask.foreach(_.cancel) cleanMDCForTask(taskName, mdcProperties) runningTasks.remove(taskId) if (taskStarted) { From d96c623c76015cc20bc19d693d96275dcf39ecf7 Mon Sep 17 00:00:00 2001 From: Hyukjin Kwon Date: Tue, 12 Nov 2024 08:54:28 +0900 Subject: [PATCH 02/39] Revert "[SPARK-50222][PYTHON][FOLLOWUP] Support `spark.submit.appName` in PySpark" This reverts commit b23905ac0c628a32ac5062c2fdb90e1b3564dcde. --- python/pyspark/sql/session.py | 3 --- python/pyspark/tests/test_appsubmit.py | 20 -------------------- 2 files changed, 23 deletions(-) diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index a5b76a27b2960..748dd2cafa7c3 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -543,12 +543,9 @@ def getOrCreate(self) -> "SparkSession": session = SparkSession._instantiatedSession if session is None or session._sc._jsc is None: - SparkContext._ensure_initialized() sparkConf = SparkConf() for key, value in self._options.items(): sparkConf.set(key, value) - if sparkConf.contains("spark.submit.appName"): - sparkConf.setAppName(sparkConf.get("spark.submit.appName", "")) # This SparkContext may be an existing one. 
sc = SparkContext.getOrCreate(sparkConf) # Do not update `SparkConf` for existing `SparkContext`, as it's shared
diff --git a/python/pyspark/tests/test_appsubmit.py b/python/pyspark/tests/test_appsubmit.py
index 0645bf2dc64b3..79b6b4fa91a75 100644
--- a/python/pyspark/tests/test_appsubmit.py
+++ b/python/pyspark/tests/test_appsubmit.py
@@ -293,26 +293,6 @@ def test_user_configuration(self): out, err = proc.communicate() self.assertEqual(0, proc.returncode, msg="Process failed with error:\n {0}".format(out)) - def test_session(self): - """Make sure spark.submit.appName overrides the appName in script""" - script = self.createTempFile( - "test.py", - """ - |from pyspark.sql import SparkSession - |spark = SparkSession.builder.appName("PythonPi").getOrCreate() - |print(spark.sparkContext.appName) - |spark.stop() - """, - ) - proc = subprocess.Popen( - self.sparkSubmit + ["--master", "local", "-c", "spark.submit.appName=NEW", script], - stdout=subprocess.PIPE, - stderr=subprocess.STDOUT, - ) - out, err = proc.communicate() - self.assertEqual(0, proc.returncode) - self.assertIn("NEW", out.decode("utf-8")) - if __name__ == "__main__": from pyspark.tests.test_appsubmit import * # noqa: F401

From 5432cef483c499e3548a30d608550bac9fce53ec Mon Sep 17 00:00:00 2001
From: jingz-db
Date: Tue, 12 Nov 2024 10:43:48 +0900
Subject: [PATCH 03/39] [SPARK-50152][SS] Support handleInitialState with state data source reader

### What changes were proposed in this pull request?

This PR adds support for users to provide a Dataframe that can be used to instantiate state in the first batch of a query using the arbitrary state API v2. More specifically, this dataframe comes from the state data source reader.

It removes the restriction that the initialState dataframe can contain only one value row per grouping key, which enables the integration with the state data source reader: in the flattened state data source reader output for composite types, multiple value rows map to the same grouping key. For example, we can take the dataframes created by the state data source reader on individual state variables, union them together, and use the resulting dataframe as the initial state for a transformWithState operator, like this:

```
+-----------+-----+---------+----------+------------+
|groupingKey|value|listValue|userMapKey|userMapValue|
+-----------+-----+---------+----------+------------+
|a          |3    |NULL     |NULL      |NULL        |
|b          |2    |NULL     |NULL      |NULL        |
|a          |NULL |1        |NULL      |NULL        |
|a          |NULL |2        |NULL      |NULL        |
|a          |NULL |3        |NULL      |NULL        |
|b          |NULL |1        |NULL      |NULL        |
|b          |NULL |2        |NULL      |NULL        |
|a          |NULL |NULL     |a         |3           |
|b          |NULL |NULL     |b         |2           |
+-----------+-----+---------+----------+------------+
```

### Why are the changes needed?

This change supports initial state handling for the integration with the state data source reader.

### Does this PR introduce _any_ user-facing change?

No. The user API is the same as in the prior PR (https://github.com/apache/spark/pull/45467) for initial state support without the state data source reader.

### How was this patch tested?

Unit test cases added in `TransformWithStateInitialStateSuite`.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #48686 from jingz-db/initial-state-reader-integration.
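For reference, a rough Scala sketch of the intended flow (not part of the patch; it mirrors the new `TransformWithStateInitialStateSuite` below — `checkpointDir`, `newInput`, the state variable names, `UnionInitialStateRow` and `InitialStatefulProcessorWithStateDataSource` are placeholders or helpers taken from that suite):

```scala
import org.apache.spark.sql.execution.datasources.v2.state.StateSourceOptions
import org.apache.spark.sql.streaming.{OutputMode, TimeMode}
import spark.implicits._

// Read one state variable of a previous transformWithState query back as a dataframe.
def readVar(name: String) = spark.read.format("statestore")
  .option(StateSourceOptions.PATH, checkpointDir)
  .option(StateSourceOptions.STATE_VAR_NAME, name)
  .option(StateSourceOptions.FLATTEN_COLLECTION_TYPES, true)
  .load()
  .drop("partition_id")

// Rename the per-variable columns to a common, union-friendly schema.
val valueDf = readVar("testVal")
  .selectExpr("key.value AS groupingKey", "value.value AS value")
val listDf = readVar("testList")
  .selectExpr("key.value AS groupingKey", "list_element.value AS listValue")
val mapDf = readVar("testMap")
  .selectExpr("key.value AS groupingKey",
    "user_map_key.value AS userMapKey", "user_map_value.value AS userMapValue")

// Union by name, filling missing columns with NULL; this yields the shape shown above,
// with possibly several rows per grouping key.
val initDf = valueDf
  .unionByName(listDf, allowMissingColumns = true)
  .unionByName(mapDf, allowMissingColumns = true)
  .as[UnionInitialStateRow]
  .groupByKey(_.groupingKey)

// Start a new transformWithState query with the unioned dataframe as its initial state.
val restarted = newInput.toDS()
  .groupByKey(x => x)
  .transformWithState(
    new InitialStatefulProcessorWithStateDataSource(),
    TimeMode.None(), OutputMode.Append(), initDf)
```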
Lead-authored-by: jingz-db Co-authored-by: Jungtaek Lim Signed-off-by: Jungtaek Lim --- .../resources/error/error-conditions.json | 6 - .../sql/streaming/StatefulProcessor.scala | 3 + .../streaming/TransformWithStateExec.scala | 8 +- .../streaming/state/StateStoreErrors.scala | 10 - .../TransformWithStateInitialStateSuite.scala | 286 +++++++++++++++--- 5 files changed, 257 insertions(+), 56 deletions(-) diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json index 7ef6feae08452..154fee2eefb79 100644 --- a/common/utils/src/main/resources/error/error-conditions.json +++ b/common/utils/src/main/resources/error/error-conditions.json @@ -4233,12 +4233,6 @@ ], "sqlState" : "42802" }, - "STATEFUL_PROCESSOR_CANNOT_REINITIALIZE_STATE_ON_KEY" : { - "message" : [ - "Cannot re-initialize state on the same grouping key during initial state handling for stateful processor. Invalid grouping key=." - ], - "sqlState" : "42802" - }, "STATEFUL_PROCESSOR_DUPLICATE_STATE_VARIABLE_DEFINED" : { "message" : [ "State variable with name has already been defined in the StatefulProcessor." diff --git a/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessor.scala b/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessor.scala index 719d1e572c20d..55477b4dda0c9 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessor.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessor.scala @@ -121,6 +121,9 @@ private[sql] abstract class StatefulProcessorWithInitialState[K, I, O, S] /** * Function that will be invoked only in the first batch for users to process initial states. + * The provided initial state can be arbitrary dataframe with the same grouping key schema with + * the input rows, e.g. dataframe from data source reader of existing streaming query + * checkpoint. 
* * @param key * \- grouping key diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala index 4f7a10f882453..2b26d18019d12 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala @@ -271,13 +271,9 @@ case class TransformWithStateExec( ImplicitGroupingKeyTracker.setImplicitKey(keyObj) val initStateObjIter = initStateIter.map(getInitStateValueObj.apply) - var seenInitStateOnKey = false initStateObjIter.foreach { initState => - // cannot re-initialize state on the same grouping key during initial state handling - if (seenInitStateOnKey) { - throw StateStoreErrors.cannotReInitializeStateOnKey(keyObj.toString) - } - seenInitStateOnKey = true + // allow multiple initial state rows on the same grouping key for integration + // with state data source reader with initial state statefulProcessor .asInstanceOf[StatefulProcessorWithInitialState[Any, Any, Any, Any]] .handleInitialState(keyObj, initState, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreErrors.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreErrors.scala index e4b370e67b018..45ad7e14c52d3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreErrors.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreErrors.scala @@ -123,11 +123,6 @@ object StateStoreErrors { new StatefulProcessorCannotPerformOperationWithInvalidHandleState(operationType, handleState) } - def cannotReInitializeStateOnKey(groupingKey: String): - StatefulProcessorCannotReInitializeState = { - new StatefulProcessorCannotReInitializeState(groupingKey) - } - def cannotProvideTTLConfigForTimeMode(stateName: String, timeMode: String): StatefulProcessorCannotAssignTTLInTimeMode = { new StatefulProcessorCannotAssignTTLInTimeMode(stateName, timeMode) @@ -272,11 +267,6 @@ class StatefulProcessorCannotPerformOperationWithInvalidHandleState( messageParameters = Map("operationType" -> operationType, "handleState" -> handleState) ) -class StatefulProcessorCannotReInitializeState(groupingKey: String) - extends SparkUnsupportedOperationException( - errorClass = "STATEFUL_PROCESSOR_CANNOT_REINITIALIZE_STATE_ON_KEY", - messageParameters = Map("groupingKey" -> groupingKey)) - class StateStoreUnsupportedOperationOnMissingColumnFamily( operationType: String, colFamilyName: String) extends SparkUnsupportedOperationException( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateInitialStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateInitialStateSuite.scala index 35ac8a4687eb0..360656a76f350 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateInitialStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateInitialStateSuite.scala @@ -17,11 +17,11 @@ package org.apache.spark.sql.streaming -import org.apache.spark.SparkUnsupportedOperationException -import org.apache.spark.sql.{Dataset, Encoders, KeyValueGroupedDataset} +import org.apache.spark.sql.{DataFrame, Dataset, Encoders, KeyValueGroupedDataset} +import org.apache.spark.sql.execution.datasources.v2.state.StateSourceOptions import 
org.apache.spark.sql.execution.streaming.MemoryStream import org.apache.spark.sql.execution.streaming.state.{AlsoTestWithChangelogCheckpointingEnabled, RocksDBStateStoreProvider} -import org.apache.spark.sql.functions.timestamp_seconds +import org.apache.spark.sql.functions.{col, timestamp_seconds} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.util.StreamManualClock @@ -29,6 +29,21 @@ case class InitInputRow(key: String, action: String, value: Double) case class InputRowForInitialState( key: String, value: Double, entries: List[Double], mapping: Map[Double, Int]) +case class UnionInitialStateRow( + groupingKey: String, + value: Option[Long], + listValue: Option[Long], + userMapKey: Option[String], + userMapValue: Option[Long] +) + +case class UnionUnflattenInitialStateRow( + groupingKey: String, + value: Option[Long], + listValue: Option[Seq[Long]], + mapValue: Option[Map[String, Long]] +) + abstract class StatefulProcessorWithInitialStateTestClass[V] extends StatefulProcessorWithInitialState[ String, InitInputRow, (String, String, Double), V] { @@ -86,6 +101,86 @@ abstract class StatefulProcessorWithInitialStateTestClass[V] } } +/** + * Class that updates all state variables with input rows. Act as a counter - + * keep the count in value state; keep all the occurrences in list state; and + * keep a map of key -> occurrence count in the map state. + */ +abstract class InitialStateWithStateDataSourceBase[V] + extends StatefulProcessorWithInitialState[ + String, String, (String, String), V] { + @transient var _valState: ValueState[Long] = _ + @transient var _listState: ListState[Long] = _ + @transient var _mapState: MapState[String, Long] = _ + + override def init( + outputMode: OutputMode, + timeMode: TimeMode): Unit = { + _valState = getHandle.getValueState[Long]("testVal", Encoders.scalaLong, TTLConfig.NONE) + _listState = getHandle.getListState[Long]("testList", Encoders.scalaLong, TTLConfig.NONE) + _mapState = getHandle.getMapState[String, Long]( + "testMap", Encoders.STRING, Encoders.scalaLong, TTLConfig.NONE) + } + + override def handleInputRows( + key: String, + inputRows: Iterator[String], + timerValues: TimerValues): Iterator[(String, String)] = { + val curCountValue = if (_valState.exists()) { + _valState.get() + } else { + 0L + } + var cnt = curCountValue + inputRows.foreach { row => + cnt += 1 + _listState.appendValue(cnt) + val mapCurVal = if (_mapState.containsKey(row)) { + _mapState.getValue(row) + } else { + 0 + } + _mapState.updateValue(row, mapCurVal + 1L) + } + _valState.update(cnt) + Iterator.single((key, cnt.toString)) + } + + override def close(): Unit = super.close() +} + +class InitialStatefulProcessorWithStateDataSource + extends InitialStateWithStateDataSourceBase[UnionInitialStateRow] { + override def handleInitialState( + key: String, initialState: UnionInitialStateRow, timerValues: TimerValues): Unit = { + if (initialState.value.isDefined) { + _valState.update(initialState.value.get) + } else if (initialState.listValue.isDefined) { + _listState.appendValue(initialState.listValue.get) + } else if (initialState.userMapKey.isDefined) { + _mapState.updateValue( + initialState.userMapKey.get, initialState.userMapValue.get) + } + } +} + +class InitialStatefulProcessorWithUnflattenStateDataSource + extends InitialStateWithStateDataSourceBase[UnionUnflattenInitialStateRow] { + override def handleInitialState( + key: String, initialState: UnionUnflattenInitialStateRow, timerValues: TimerValues): Unit = { + if 
(initialState.value.isDefined) { + _valState.update(initialState.value.get) + } else if (initialState.listValue.isDefined) { + _listState.appendList( + initialState.listValue.get.toArray) + } else if (initialState.mapValue.isDefined) { + initialState.mapValue.get.keys.foreach { key => + _mapState.updateValue(key, initialState.mapValue.get.get(key).get) + } + } + } +} + class AccumulateStatefulProcessorWithInitState extends StatefulProcessorWithInitialStateTestClass[(String, Double)] { override def handleInitialState( @@ -398,37 +493,6 @@ class TransformWithStateInitialStateSuite extends StateStoreMetricsTest checkAnswer(df, Seq(("k1", "getOption", 37.0)).toDF()) } - test("transformWithStateWithInitialState - " + - "cannot re-initialize state during initial state handling") { - withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> - classOf[RocksDBStateStoreProvider].getName) { - val initDf = Seq(("init_1", 40.0), ("init_2", 100.0), ("init_1", 50.0)).toDS() - .groupByKey(x => x._1).mapValues(x => x) - val inputData = MemoryStream[InitInputRow] - val query = inputData.toDS() - .groupByKey(x => x.key) - .transformWithState(new AccumulateStatefulProcessorWithInitState(), - TimeMode.None(), - OutputMode.Append(), - initDf) - - testStream(query, OutputMode.Update())( - AddData(inputData, InitInputRow("k1", "add", 50.0)), - Execute { q => - val e = intercept[Exception] { - q.processAllAvailable() - } - checkError( - exception = e.getCause.asInstanceOf[SparkUnsupportedOperationException], - condition = "STATEFUL_PROCESSOR_CANNOT_REINITIALIZE_STATE_ON_KEY", - sqlState = Some("42802"), - parameters = Map("groupingKey" -> "init_1") - ) - } - ) - } - } - test("transformWithStateWithInitialState - streaming with processing time timer, " + "can emit expired initial state rows when grouping key is not received for new input rows") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> @@ -503,4 +567,158 @@ class TransformWithStateInitialStateSuite extends StateStoreMetricsTest ) } } + + Seq(true, false).foreach { flattenOption => + Seq(("5", "2"), ("5", "8"), ("5", "5")).foreach { partitions => + test("state data source reader dataframe as initial state " + + s"(flatten option=$flattenOption, shuffle partition for 1st stream=${partitions._1}, " + + s"shuffle partition for 1st stream=${partitions._2})") { + withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> + classOf[RocksDBStateStoreProvider].getName) { + withTempPaths(2) { checkpointDirs => + SQLConf.get.setConfString(SQLConf.SHUFFLE_PARTITIONS.key, partitions._1) + val inputData = MemoryStream[String] + val result = inputData.toDS() + .groupByKey(x => x) + .transformWithState(new InitialStatefulProcessorWithStateDataSource(), + TimeMode.None(), + OutputMode.Update()) + + testStream(result, OutputMode.Update())( + StartStream(checkpointLocation = checkpointDirs(0).getCanonicalPath), + AddData(inputData, "a", "b"), + CheckNewAnswer(("a", "1"), ("b", "1")), + AddData(inputData, "a", "b", "a"), + CheckNewAnswer(("a", "3"), ("b", "2")) + ) + + // We are trying to mimic a use case where users load all state data rows + // from a previous tws query as initial state and start a new tws query. + // In this use case, users will need to create a single dataframe with + // all the state rows from different state variables with different schema. + // We can only read from one state variable from one state data source reader + // query, and they are of different schema. 
We will get one dataframe from each + // state variable, and we union them together into a single dataframe. + val valueDf = spark.read + .format("statestore") + .option(StateSourceOptions.PATH, checkpointDirs(0).getAbsolutePath) + .option(StateSourceOptions.STATE_VAR_NAME, "testVal") + .load() + .drop("partition_id") + + val listDf = spark.read + .format("statestore") + .option(StateSourceOptions.PATH, checkpointDirs(0).getAbsolutePath) + .option(StateSourceOptions.STATE_VAR_NAME, "testList") + .option(StateSourceOptions.FLATTEN_COLLECTION_TYPES, flattenOption) + .load() + .drop("partition_id") + + val mapDf = spark.read + .format("statestore") + .option(StateSourceOptions.PATH, checkpointDirs(0).getAbsolutePath) + .option(StateSourceOptions.STATE_VAR_NAME, "testMap") + .option(StateSourceOptions.FLATTEN_COLLECTION_TYPES, flattenOption) + .load() + .drop("partition_id") + + // create a df where each row contains all value, list, map state rows + // fill the missing column with null. + SQLConf.get.setConfString(SQLConf.SHUFFLE_PARTITIONS.key, partitions._2) + val inputData2 = MemoryStream[String] + val query = startQueryWithDataSourceDataframeAsInitState( + flattenOption, valueDf, listDf, mapDf, inputData2) + + testStream(query, OutputMode.Update())( + StartStream(checkpointLocation = checkpointDirs(1).getCanonicalPath), + // check initial state is updated for state vars + AddData(inputData2, "c"), + CheckNewAnswer(("c", "1")), + Execute { _ => + val valueDf2 = spark.read + .format("statestore") + .option(StateSourceOptions.PATH, checkpointDirs(1).getAbsolutePath) + .option(StateSourceOptions.STATE_VAR_NAME, "testVal") + .option(StateSourceOptions.FLATTEN_COLLECTION_TYPES, flattenOption) + .load() + .drop("partition_id") + .filter(col("key.value") =!= "c") + + val listDf2 = spark.read + .format("statestore") + .option(StateSourceOptions.PATH, checkpointDirs(1).getAbsolutePath) + .option(StateSourceOptions.STATE_VAR_NAME, "testList") + .option(StateSourceOptions.FLATTEN_COLLECTION_TYPES, flattenOption) + .load() + .drop("partition_id") + .filter(col("key.value") =!= "c") + + val mapDf2 = spark.read + .format("statestore") + .option(StateSourceOptions.PATH, checkpointDirs(1).getAbsolutePath) + .option(StateSourceOptions.STATE_VAR_NAME, "testMap") + .option(StateSourceOptions.FLATTEN_COLLECTION_TYPES, flattenOption) + .load() + .drop("partition_id") + .filter(col("key.value") =!= "c") + + checkAnswer(valueDf, valueDf2) + checkAnswer(listDf, listDf2) + checkAnswer(mapDf, mapDf2) + } + ) + } + } + } + } + } + + private def startQueryWithDataSourceDataframeAsInitState( + flattenOption: Boolean, + valDf: DataFrame, + listDf: DataFrame, + mapDf: DataFrame, + inputData: MemoryStream[String]): DataFrame = { + if (flattenOption) { + // when we read the state rows with flattened option set to true, values of a composite + // state variable will be flattened into multiple rows where each row is a + // key -> single value pair + val valueDf = valDf.selectExpr("key.value AS groupingKey", "value.value AS value") + val flattenListDf = listDf + .selectExpr("key.value AS groupingKey", "list_element.value AS listValue") + val flattenMapDf = mapDf + .selectExpr( + "key.value AS groupingKey", + "user_map_key.value AS userMapKey", + "user_map_value.value AS userMapValue") + val df_joined = + valueDf.unionByName(flattenListDf, true) + .unionByName(flattenMapDf, true) + val kvDataSet = inputData.toDS().groupByKey(x => x) + val initDf = df_joined.as[UnionInitialStateRow].groupByKey(x => x.groupingKey) + 
(kvDataSet.transformWithState( + new InitialStatefulProcessorWithStateDataSource(), + TimeMode.None(), OutputMode.Append(), initDf).toDF()) + } else { + // when we read the state rows with flattened option set to false, values of a composite + // state variable will be composed into a single row of list/map type + val valueDf = valDf.selectExpr("key.value AS groupingKey", "value.value AS value") + val unflattenListDf = listDf + .selectExpr("key.value AS groupingKey", + "list_value.value as listValue") + val unflattenMapDf = mapDf + .selectExpr( + "key.value AS groupingKey", + "map_from_entries(transform(map_entries(map_value), x -> " + + "struct(x.key.value, x.value.value))) as mapValue") + val df_joined = + valueDf.unionByName(unflattenListDf, true) + .unionByName(unflattenMapDf, true) + val kvDataSet = inputData.toDS().groupByKey(x => x) + val initDf = df_joined.as[UnionUnflattenInitialStateRow].groupByKey(x => x.groupingKey) + kvDataSet.transformWithState( + new InitialStatefulProcessorWithUnflattenStateDataSource(), + TimeMode.None(), OutputMode.Append(), initDf).toDF() + } + } } From 0a66c371cd4c49d4c464f06078104eac0232a3d6 Mon Sep 17 00:00:00 2001 From: Takuya Ueshin Date: Tue, 12 Nov 2024 10:58:25 +0900 Subject: [PATCH 04/39] [SPARK-50130][SQL][PYTHON] Add DataFrame APIs for scalar and exists subqueries ### What changes were proposed in this pull request? Adds the following DataFrame APIs for subqueries to Spark Classic. - `scalar()` - `exists()` Also, add `outer()` to `Column` to specify outer references. #### Examples: For the following tables `l` and `r`: ```py >>> spark.table("l").printSchema() root |-- a: long (nullable = true) |-- b: double (nullable = true) >>> spark.table("r").printSchema() root |-- c: long (nullable = true) |-- d: double (nullable = true) ``` ```py from pyspark.sql import functions as sf # select * from l where b < (select max(d) from r where a = c) spark.table("l").where( sf.col("b") < ( spark.table("r") .where(sf.col("a").outer() == sf.col("c")) .select(sf.max("d")) .scalar() ) ) # select a, (select sum(b) from l l2 where l2.a = l1.a) sum_b from l l1 spark.table("l").select( "a", ( spark.table("l") .where(sf.col("a") == sf.col("a").outer()) .select(sf.sum("b")) .scalar() .alias("sum_b") ), ) # select * from l where exists (select * from r where l.a = r.c) spark.table("l").where( spark.table("r").where(sf.col("a").outer() == sf.col("c")).exists() ) ``` ### Why are the changes needed? Subquery APIs are missing in DataFrame API. ### Does this PR introduce _any_ user-facing change? Yes, new DataFrame APIs for subqueries will be available. ### How was this patch tested? Added the related tests. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48664 from ueshin/issues/SPARK-50130/scalar_exists. 
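For reference, the same examples expressed with the new Scala API (a sketch, not from this patch; it assumes the `l`/`r` tables above and the `Column.outer()`, `Dataset.scalar()` and `Dataset.exists()` methods added here):

```scala
import org.apache.spark.sql.functions._

// select * from l where b < (select max(d) from r where a = c)
spark.table("l").where(
  col("b") < spark.table("r")
    .where(col("a").outer() === col("c"))
    .select(max("d"))
    .scalar())

// select a, (select sum(b) from l l2 where l2.a = l1.a) sum_b from l l1
spark.table("l").select(
  col("a"),
  spark.table("l")
    .where(col("a") === col("a").outer())
    .select(sum("b"))
    .scalar()
    .alias("sum_b"))

// select * from l where exists (select * from r where l.a = r.c)
spark.table("l").where(
  spark.table("r").where(col("a").outer() === col("c")).exists())
```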
Authored-by: Takuya Ueshin Signed-off-by: Hyukjin Kwon --- .../resources/error/error-conditions.json | 6 + .../scala/org/apache/spark/sql/Dataset.scala | 13 + dev/sparktestsupport/modules.py | 2 + python/pyspark/sql/classic/column.py | 4 + python/pyspark/sql/classic/dataframe.py | 9 + python/pyspark/sql/column.py | 14 + python/pyspark/sql/connect/column.py | 14 +- python/pyspark/sql/connect/dataframe.py | 18 + python/pyspark/sql/dataframe.py | 161 ++++++ .../sql/tests/connect/test_parity_subquery.py | 38 ++ python/pyspark/sql/tests/test_subquery.py | 487 ++++++++++++++++++ .../scala/org/apache/spark/sql/Column.scala | 17 +- .../org/apache/spark/sql/api/Dataset.scala | 27 + .../spark/sql/internal/columnNodes.scala | 18 + .../sql/catalyst/analysis/Analyzer.scala | 24 +- .../sql/catalyst/analysis/CheckAnalysis.scala | 21 +- .../analysis/ColumnResolutionHelper.scala | 41 +- .../sql/catalyst/analysis/unresolved.scala | 40 ++ .../sql/catalyst/expressions/Expression.scala | 16 +- .../sql/catalyst/expressions/subquery.scala | 15 +- .../sql/catalyst/optimizer/Optimizer.scala | 2 +- .../sql/catalyst/optimizer/expressions.scala | 2 +- .../sql/catalyst/optimizer/subquery.scala | 15 +- .../sql/catalyst/trees/TreePatterns.scala | 5 + .../scala/org/apache/spark/sql/Dataset.scala | 67 ++- .../spark/sql/execution/QueryExecution.scala | 3 + .../adaptive/PlanAdaptiveSubqueries.scala | 2 +- .../sql/internal/columnNodeSupport.scala | 13 + .../spark/sql/DataFrameSubquerySuite.scala | 367 +++++++++++++ 29 files changed, 1403 insertions(+), 58 deletions(-) create mode 100644 python/pyspark/sql/tests/connect/test_parity_subquery.py create mode 100644 python/pyspark/sql/tests/test_subquery.py create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/DataFrameSubquerySuite.scala diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json index 154fee2eefb79..987fc706f7c0b 100644 --- a/common/utils/src/main/resources/error/error-conditions.json +++ b/common/utils/src/main/resources/error/error-conditions.json @@ -4691,6 +4691,12 @@ ], "sqlState" : "42KD9" }, + "UNANALYZABLE_EXPRESSION" : { + "message" : [ + "The plan contains an unanalyzable expression that holds the analysis." + ], + "sqlState" : "03000" + }, "UNBOUND_SQL_PARAMETER" : { "message" : [ "Found the unbound parameter: . Please, fix `args` and provide a mapping of the parameter to either a SQL literal or collection constructor functions such as `map()`, `array()`, `struct()`." diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala index 5e50e34e8c35d..631e9057f8d15 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -585,6 +585,19 @@ class Dataset[T] private[sql] ( def transpose(): DataFrame = buildTranspose(Seq.empty) + // TODO(SPARK-50134): Support scalar Subquery API in Spark Connect + // scalastyle:off not.implemented.error.usage + /** @inheritdoc */ + def scalar(): Column = { + ??? + } + + /** @inheritdoc */ + def exists(): Column = { + ??? 
+ } + // scalastyle:on not.implemented.error.usage + /** @inheritdoc */ def limit(n: Int): Dataset[T] = sparkSession.newDataset(agnosticEncoder) { builder => builder.getLimitBuilder diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index 701ebb54dbbf2..b8702113a26c7 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -545,6 +545,7 @@ def __hash__(self): "pyspark.sql.tests.streaming.test_streaming_foreach", "pyspark.sql.tests.streaming.test_streaming_foreach_batch", "pyspark.sql.tests.streaming.test_streaming_listener", + "pyspark.sql.tests.test_subquery", "pyspark.sql.tests.test_types", "pyspark.sql.tests.test_udf", "pyspark.sql.tests.test_udf_profiler", @@ -1044,6 +1045,7 @@ def __hash__(self): "pyspark.sql.tests.connect.test_parity_observation", "pyspark.sql.tests.connect.test_parity_repartition", "pyspark.sql.tests.connect.test_parity_stat", + "pyspark.sql.tests.connect.test_parity_subquery", "pyspark.sql.tests.connect.test_parity_types", "pyspark.sql.tests.connect.test_parity_column", "pyspark.sql.tests.connect.test_parity_readwriter", diff --git a/python/pyspark/sql/classic/column.py b/python/pyspark/sql/classic/column.py index 931378a08187f..c08eac7f6a049 100644 --- a/python/pyspark/sql/classic/column.py +++ b/python/pyspark/sql/classic/column.py @@ -605,6 +605,10 @@ def over(self, window: "WindowSpec") -> ParentColumn: jc = self._jc.over(window._jspec) return Column(jc) + def outer(self) -> ParentColumn: + jc = self._jc.outer() + return Column(jc) + def __nonzero__(self) -> None: raise PySparkValueError( errorClass="CANNOT_CONVERT_COLUMN_INTO_BOOL", diff --git a/python/pyspark/sql/classic/dataframe.py b/python/pyspark/sql/classic/dataframe.py index fad3fac9890b7..169755c753907 100644 --- a/python/pyspark/sql/classic/dataframe.py +++ b/python/pyspark/sql/classic/dataframe.py @@ -42,6 +42,7 @@ from pyspark.resource import ResourceProfile from pyspark._globals import _NoValueType from pyspark.errors import ( + AnalysisException, PySparkTypeError, PySparkValueError, PySparkIndexError, @@ -214,6 +215,8 @@ def schema(self) -> StructType: self._schema = cast( StructType, _parse_datatype_json_string(self._jdf.schema().json()) ) + except AnalysisException as e: + raise e except Exception as e: raise PySparkValueError( errorClass="CANNOT_PARSE_DATATYPE", @@ -1783,6 +1786,12 @@ def transpose(self, indexColumn: Optional["ColumnOrName"] = None) -> ParentDataF else: return DataFrame(self._jdf.transpose(), self.sparkSession) + def scalar(self) -> Column: + return Column(self._jdf.scalar()) + + def exists(self) -> Column: + return Column(self._jdf.exists()) + @property def executionInfo(self) -> Optional["ExecutionInfo"]: raise PySparkValueError( diff --git a/python/pyspark/sql/column.py b/python/pyspark/sql/column.py index b4c14d98f4ccd..06dd2860fe406 100644 --- a/python/pyspark/sql/column.py +++ b/python/pyspark/sql/column.py @@ -1521,6 +1521,20 @@ def over(self, window: "WindowSpec") -> "Column": """ ... + @dispatch_col_method + def outer(self) -> "Column": + """ + Mark this column reference as an outer reference for subqueries. + + .. versionadded:: 4.0.0 + + See Also + -------- + pyspark.sql.dataframe.DataFrame.scalar + pyspark.sql.dataframe.DataFrame.exists + """ + ... + @dispatch_col_method def __nonzero__(self) -> None: ... 
diff --git a/python/pyspark/sql/connect/column.py b/python/pyspark/sql/connect/column.py index be0d054edfb46..e840081146340 100644 --- a/python/pyspark/sql/connect/column.py +++ b/python/pyspark/sql/connect/column.py @@ -30,7 +30,12 @@ ) from pyspark.sql.column import Column as ParentColumn -from pyspark.errors import PySparkTypeError, PySparkAttributeError, PySparkValueError +from pyspark.errors import ( + PySparkTypeError, + PySparkAttributeError, + PySparkValueError, + PySparkNotImplementedError, +) from pyspark.sql.types import DataType from pyspark.sql.utils import enum_to_value @@ -454,6 +459,13 @@ def over(self, window: "WindowSpec") -> ParentColumn: # type: ignore[override] return Column(WindowExpression(windowFunction=self._expr, windowSpec=window)) + def outer(self) -> ParentColumn: + # TODO(SPARK-50134): Implement this method + raise PySparkNotImplementedError( + errorClass="NOT_IMPLEMENTED", + messageParameters={"feature": "outer()"}, + ) + def isin(self, *cols: Any) -> ParentColumn: if len(cols) == 1 and isinstance(cols[0], (list, set)): _cols = list(cols[0]) diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py index 1a9894b6fac54..e85efeb592dff 100644 --- a/python/pyspark/sql/connect/dataframe.py +++ b/python/pyspark/sql/connect/dataframe.py @@ -1784,6 +1784,20 @@ def transpose(self, indexColumn: Optional["ColumnOrName"] = None) -> ParentDataF self._session, ) + def scalar(self) -> Column: + # TODO(SPARK-50134): Implement this method + raise PySparkNotImplementedError( + errorClass="NOT_IMPLEMENTED", + messageParameters={"feature": "scalar()"}, + ) + + def exists(self) -> Column: + # TODO(SPARK-50134): Implement this method + raise PySparkNotImplementedError( + errorClass="NOT_IMPLEMENTED", + messageParameters={"feature": "exists()"}, + ) + @property def schema(self) -> StructType: # Schema caching is correct in most cases. Connect is lazy by nature. This means that @@ -2248,6 +2262,10 @@ def _test() -> None: del pyspark.sql.dataframe.DataFrame.toJSON.__doc__ del pyspark.sql.dataframe.DataFrame.rdd.__doc__ + # TODO(SPARK-50134): Support subquery in connect + del pyspark.sql.dataframe.DataFrame.scalar.__doc__ + del pyspark.sql.dataframe.DataFrame.exists.__doc__ + globs["spark"] = ( PySparkSession.builder.appName("sql.connect.dataframe tests") .remote(os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[4]")) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 62f2129e5be62..8a5b982bc7f23 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -6476,6 +6476,167 @@ def transpose(self, indexColumn: Optional["ColumnOrName"] = None) -> "DataFrame" """ ... + def scalar(self) -> Column: + """ + Return a `Column` object for a SCALAR Subquery containing exactly one row and one column. + + The `scalar()` method is useful for extracting a `Column` object that represents a scalar + value from a DataFrame, especially when the DataFrame results from an aggregation or + single-value computation. This returned `Column` can then be used directly in `select` + clauses or as predicates in filters on the outer DataFrame, enabling dynamic data filtering + and calculations based on scalar values. + + .. versionadded:: 4.0.0 + + Returns + ------- + :class:`Column` + A `Column` object representing a SCALAR subquery. + + Examples + -------- + Setup a sample DataFrame. + + >>> data = [ + ... (1, "Alice", 45000, 101), (2, "Bob", 54000, 101), (3, "Charlie", 29000, 102), + ... 
(4, "David", 61000, 102), (5, "Eve", 48000, 101), + ... ] + >>> employees = spark.createDataFrame(data, ["id", "name", "salary", "department_id"]) + + Example 1 (non-correlated): Filter for employees with salary greater than the average + salary. + + >>> from pyspark.sql import functions as sf + >>> employees.where( + ... sf.col("salary") > employees.select(sf.avg("salary")).scalar() + ... ).select("name", "salary", "department_id").show() + +-----+------+-------------+ + | name|salary|department_id| + +-----+------+-------------+ + | Bob| 54000| 101| + |David| 61000| 102| + | Eve| 48000| 101| + +-----+------+-------------+ + + Example 2 (correlated): Filter for employees with salary greater than the average salary + in their department. + + >>> from pyspark.sql import functions as sf + >>> employees.where( + ... sf.col("salary") + ... > employees.where(sf.col("department_id") == sf.col("department_id").outer()) + ... .select(sf.avg("salary")).scalar() + ... ).select("name", "salary", "department_id").show() + +-----+------+-------------+ + | name|salary|department_id| + +-----+------+-------------+ + | Bob| 54000| 101| + |David| 61000| 102| + +-----+------+-------------+ + + Example 3 (in select): Select the name, salary, and the proportion of the salary in the + department. + + >>> from pyspark.sql import functions as sf + >>> employees.select( + ... "name", "salary", "department_id", + ... sf.format_number( + ... sf.lit(100) * sf.col("salary") / + ... employees.where(sf.col("department_id") == sf.col("department_id").outer()) + ... .select(sf.sum("salary")).scalar().alias("avg_salary"), + ... 1 + ... ).alias("salary_proportion_in_department") + ... ).show() + +-------+------+-------------+-------------------------------+ + | name|salary|department_id|salary_proportion_in_department| + +-------+------+-------------+-------------------------------+ + | Alice| 45000| 101| 30.6| + | Bob| 54000| 101| 36.7| + |Charlie| 29000| 102| 32.2| + | Eve| 48000| 101| 32.7| + | David| 61000| 102| 67.8| + +-------+------+-------------+-------------------------------+ + """ + ... + + def exists(self) -> Column: + """ + Return a `Column` object for an EXISTS Subquery. + + The `exists` method provides a way to create a boolean column that checks for the presence + of related records in a subquery. When applied within a `DataFrame`, this method allows you + to filter rows based on whether matching records exist in the related dataset. The resulting + `Column` object can be used directly in filtering conditions or as a computed column. + + .. versionadded:: 4.0.0 + + Returns + ------- + :class:`Column` + A `Column` object representing an EXISTS subquery + + Examples + -------- + Setup sample data for customers and orders. + + >>> data_customers = [ + ... (101, "Alice", "USA"), (102, "Bob", "Canada"), (103, "Charlie", "USA"), + ... (104, "David", "Australia") + ... ] + >>> data_orders = [ + ... (1, 101, "2023-01-15", 250), (2, 102, "2023-01-20", 300), + ... (3, 103, "2023-01-25", 400), (4, 101, "2023-02-05", 150) + ... ] + >>> customers = spark.createDataFrame( + ... data_customers, ["customer_id", "customer_name", "country"]) + >>> orders = spark.createDataFrame( + ... data_orders, ["order_id", "customer_id", "order_date", "total_amount"]) + + Example 1: Filter for customers who have placed at least one order. + + >>> from pyspark.sql import functions as sf + >>> customers.where( + ... orders.where(sf.col("customer_id") == sf.col("customer_id").outer()).exists() + ... 
).orderBy("customer_id").show() + +-----------+-------------+-------+ + |customer_id|customer_name|country| + +-----------+-------------+-------+ + | 101| Alice| USA| + | 102| Bob| Canada| + | 103| Charlie| USA| + +-----------+-------------+-------+ + + Example 2: Filter for customers who have never placed an order. + + >>> from pyspark.sql import functions as sf + >>> customers.where( + ... ~orders.where(sf.col("customer_id") == sf.col("customer_id").outer()).exists() + ... ).orderBy("customer_id").show() + +-----------+-------------+---------+ + |customer_id|customer_name| country| + +-----------+-------------+---------+ + | 104| David|Australia| + +-----------+-------------+---------+ + + Example 3: Find Orders from Customers in the USA. + + >>> from pyspark.sql import functions as sf + >>> orders.where( + ... customers.where( + ... (sf.col("customer_id") == sf.col("customer_id").outer()) + ... & (sf.col("country") == "USA") + ... ).exists() + ... ).orderBy("order_id").show() + +--------+-----------+----------+------------+ + |order_id|customer_id|order_date|total_amount| + +--------+-----------+----------+------------+ + | 1| 101|2023-01-15| 250| + | 3| 103|2023-01-25| 400| + | 4| 101|2023-02-05| 150| + +--------+-----------+----------+------------+ + """ + ... + @property def executionInfo(self) -> Optional["ExecutionInfo"]: """ diff --git a/python/pyspark/sql/tests/connect/test_parity_subquery.py b/python/pyspark/sql/tests/connect/test_parity_subquery.py new file mode 100644 index 0000000000000..1cba3a7d49956 --- /dev/null +++ b/python/pyspark/sql/tests/connect/test_parity_subquery.py @@ -0,0 +1,38 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import unittest + +from pyspark.sql.tests.test_subquery import SubqueryTestsMixin +from pyspark.testing.connectutils import ReusedConnectTestCase + + +@unittest.skip("TODO(SPARK-50134): Support subquery in connect") +class SubqueryParityTests(SubqueryTestsMixin, ReusedConnectTestCase): + pass + + +if __name__ == "__main__": + from pyspark.sql.tests.connect.test_parity_subquery import * # noqa: F401 + + try: + import xmlrunner # type: ignore + + testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/sql/tests/test_subquery.py b/python/pyspark/sql/tests/test_subquery.py new file mode 100644 index 0000000000000..7d50d0959c215 --- /dev/null +++ b/python/pyspark/sql/tests/test_subquery.py @@ -0,0 +1,487 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. 
+# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import unittest + +from pyspark.errors import AnalysisException, QueryContextType, SparkRuntimeException +from pyspark.sql import functions as sf +from pyspark.testing import assertDataFrameEqual +from pyspark.testing.sqlutils import ReusedSQLTestCase + + +class SubqueryTestsMixin: + @property + def df1(self): + return self.spark.createDataFrame( + [ + (1, 1.0), + (1, 2.0), + (2, 1.0), + (2, 2.0), + (3, 3.0), + (None, None), + (None, 5.0), + (6, None), + ], + ["a", "b"], + ) + + @property + def df2(self): + return self.spark.createDataFrame( + [(2, 3.0), (2, 3.0), (3, 2.0), (4, 1.0), (None, None), (None, 5.0), (6, None)], + ["c", "d"], + ) + + def test_unanalyzable_expression(self): + sub = self.spark.range(1).where(sf.col("id") == sf.col("id").outer()) + + with self.assertRaises(AnalysisException) as pe: + sub.schema + + self.check_error( + exception=pe.exception, + errorClass="UNANALYZABLE_EXPRESSION", + messageParameters={"expr": '"outer(id)"'}, + query_context_type=QueryContextType.DataFrame, + fragment="outer", + ) + + def test_simple_uncorrelated_scalar_subquery(self): + assertDataFrameEqual( + self.spark.range(1).select(self.spark.range(1).select(sf.lit(1)).scalar().alias("b")), + self.spark.sql("""select (select 1 as b) as b"""), + ) + + assertDataFrameEqual( + self.spark.range(1).select( + self.spark.range(1) + .select(self.spark.range(1).select(sf.lit(1)).scalar() + 1) + .scalar() + + 1 + ), + self.spark.sql("""select (select (select 1) + 1) + 1"""), + ) + + # string type + assertDataFrameEqual( + self.spark.range(1).select(self.spark.range(1).select(sf.lit("s")).scalar().alias("b")), + self.spark.sql("""select (select 's' as s) as b"""), + ) + + # 0 rows + assertDataFrameEqual( + self.spark.range(1).select( + self.spark.range(1).select(sf.lit("s")).limit(0).scalar().alias("b") + ), + self.spark.sql("""select (select 's' as s limit 0) as b"""), + ) + + def test_uncorrelated_scalar_subquery_with_view(self): + with self.tempView("subqueryData"): + df = self.spark.createDataFrame( + [(1, "one"), (2, "two"), (3, "three")], ["key", "value"] + ) + df.createOrReplaceTempView("subqueryData") + + assertDataFrameEqual( + self.spark.range(1).select( + self.spark.table("subqueryData") + .select("key") + .where(sf.col("key") > 2) + .orderBy("key") + .limit(1) + .scalar() + + 1 + ), + self.spark.sql( + """ + select (select key from subqueryData where key > 2 order by key limit 1) + 1 + """ + ), + ) + + assertDataFrameEqual( + self.spark.range(1).select( + (-self.spark.table("subqueryData").select(sf.max("key")).scalar()).alias( + "negative_max_key" + ) + ), + self.spark.sql( + """select -(select max(key) from subqueryData) as negative_max_key""" + ), + ) + + assertDataFrameEqual( + self.spark.range(1).select( + self.spark.table("subqueryData").select("value").limit(0).scalar() + ), + self.spark.sql("""select (select value from subqueryData limit 0)"""), + ) + + assertDataFrameEqual( + self.spark.range(1).select( + 
self.spark.table("subqueryData") + .where( + sf.col("key") + == self.spark.table("subqueryData").select(sf.max("key")).scalar() - 1 + ) + .select(sf.min("value")) + .scalar() + ), + self.spark.sql( + """ + select ( + select min(value) from subqueryData + where key = (select max(key) from subqueryData) - 1 + ) + """ + ), + ) + + def test_scalar_subquery_against_local_relations(self): + with self.tempView("t1", "t2"): + self.spark.createDataFrame([(1, 1), (2, 2)], ["c1", "c2"]).createOrReplaceTempView("t1") + self.spark.createDataFrame([(1, 1), (2, 2)], ["c1", "c2"]).createOrReplaceTempView("t2") + + assertDataFrameEqual( + self.spark.table("t1").select( + self.spark.range(1).select(sf.lit(1).alias("col")).scalar() + ), + self.spark.sql("""SELECT (select 1 as col) from t1"""), + ) + + assertDataFrameEqual( + self.spark.table("t1").select(self.spark.table("t2").select(sf.max("c1")).scalar()), + self.spark.sql("""SELECT (select max(c1) from t2) from t1"""), + ) + + assertDataFrameEqual( + self.spark.table("t1").select( + sf.lit(1) + self.spark.range(1).select(sf.lit(1).alias("col")).scalar() + ), + self.spark.sql("""SELECT 1 + (select 1 as col) from t1"""), + ) + + assertDataFrameEqual( + self.spark.table("t1").select( + "c1", self.spark.table("t2").select(sf.max("c1")).scalar() + sf.col("c2") + ), + self.spark.sql("""SELECT c1, (select max(c1) from t2) + c2 from t1"""), + ) + + assertDataFrameEqual( + self.spark.table("t1").select( + "c1", + ( + self.spark.table("t2") + .where(sf.col("c2").outer() == sf.col("c2")) + .select(sf.max("c1")) + .scalar() + ), + ), + self.spark.sql( + """SELECT c1, (select max(c1) from t2 where t1.c2 = t2.c2) from t1""" + ), + ) + + def test_correlated_scalar_subquery(self): + with self.tempView("l", "r"): + self.df1.createOrReplaceTempView("l") + self.df2.createOrReplaceTempView("r") + + with self.subTest("in where"): + assertDataFrameEqual( + self.spark.table("l").where( + sf.col("b") + < ( + self.spark.table("r") + .where(sf.col("a").outer() == sf.col("c")) + .select(sf.max("d")) + .scalar() + ) + ), + self.spark.sql( + """select * from l where b < (select max(d) from r where a = c)""" + ), + ) + + with self.subTest("in select"): + assertDataFrameEqual( + self.spark.table("l").select( + "a", + ( + self.spark.table("l") + .where(sf.col("a") == sf.col("a").outer()) + .select(sf.sum("b")) + .scalar() + .alias("sum_b") + ), + ), + self.spark.sql( + """select a, (select sum(b) from l l2 where l2.a = l1.a) sum_b from l l1""" + ), + ) + + with self.subTest("in select (null safe)"): + assertDataFrameEqual( + self.spark.table("l").select( + "a", + ( + self.spark.table("l") + .where(sf.col("a").eqNullSafe(sf.col("a").outer())) + .select(sf.sum("b")) + .scalar() + .alias("sum_b") + ), + ), + self.spark.sql( + """ + select a, (select sum(b) from l l2 where l2.a <=> l1.a) sum_b from l l1 + """ + ), + ) + + with self.subTest("in aggregate"): + assertDataFrameEqual( + self.spark.table("l") + .groupBy( + "a", + ( + self.spark.table("r") + .where(sf.col("a").outer() == sf.col("c")) + .select(sf.sum("d")) + .scalar() + .alias("sum_d") + ), + ) + .agg({}), + self.spark.sql( + """ + select a, (select sum(d) from r where a = c) sum_d from l l1 group by 1, 2 + """ + ), + ) + + with self.subTest("non-aggregated"): + with self.assertRaises(SparkRuntimeException) as pe: + self.spark.table("l").select( + "a", + ( + self.spark.table("l") + .where(sf.col("a") == sf.col("a").outer()) + .select("b") + .scalar() + ), + ).collect() + + self.check_error( + exception=pe.exception, + 
errorClass="SCALAR_SUBQUERY_TOO_MANY_ROWS", + messageParameters={}, + ) + + with self.subTest("non-equal"): + assertDataFrameEqual( + self.spark.table("l").select( + "a", + ( + self.spark.table("l") + .where(sf.col("a") < sf.col("a").outer()) + .select(sf.sum("b")) + .scalar() + .alias("sum_b") + ), + ), + self.spark.sql( + """select a, (select sum(b) from l l2 where l2.a < l1.a) sum_b from l l1""" + ), + ) + + with self.subTest("disjunctive"): + assertDataFrameEqual( + self.spark.table("l") + .where( + self.spark.table("r") + .where( + ((sf.col("a").outer() == sf.col("c")) & (sf.col("d") == sf.lit(2.0))) + | ((sf.col("a").outer() == sf.col("c")) & (sf.col("d") == sf.lit(1.0))) + ) + .select(sf.count(sf.lit(1))) + .scalar() + > 0 + ) + .select("a"), + self.spark.sql( + """ + select a + from l + where (select count(*) + from r + where (a = c and d = 2.0) or (a = c and d = 1.0)) > 0 + """ + ), + ) + + def test_exists_subquery(self): + with self.tempView("l", "r"): + self.df1.createOrReplaceTempView("l") + self.df2.createOrReplaceTempView("r") + + with self.subTest("EXISTS"): + assertDataFrameEqual( + self.spark.table("l").where( + self.spark.table("r").where(sf.col("a").outer() == sf.col("c")).exists() + ), + self.spark.sql( + """select * from l where exists (select * from r where l.a = r.c)""" + ), + ) + + assertDataFrameEqual( + self.spark.table("l").where( + self.spark.table("r").where(sf.col("a").outer() == sf.col("c")).exists() + & (sf.col("a") <= sf.lit(2)) + ), + self.spark.sql( + """ + select * from l where exists (select * from r where l.a = r.c) and l.a <= 2 + """ + ), + ) + + with self.subTest("NOT EXISTS"): + assertDataFrameEqual( + self.spark.table("l").where( + ~self.spark.table("r").where(sf.col("a").outer() == sf.col("c")).exists() + ), + self.spark.sql( + """select * from l where not exists (select * from r where l.a = r.c)""" + ), + ) + + assertDataFrameEqual( + self.spark.table("l").where( + ~( + self.spark.table("r") + .where( + (sf.col("a").outer() == sf.col("c")) + & (sf.col("b").outer() < sf.col("d")) + ) + .exists() + ) + ), + self.spark.sql( + """ + select * from l + where not exists (select * from r where l.a = r.c and l.b < r.d) + """ + ), + ) + + with self.subTest("EXISTS within OR"): + assertDataFrameEqual( + self.spark.table("l").where( + self.spark.table("r").where(sf.col("a").outer() == sf.col("c")).exists() + | self.spark.table("r").where(sf.col("a").outer() == sf.col("c")).exists() + ), + self.spark.sql( + """ + select * from l where exists (select * from r where l.a = r.c) + or exists (select * from r where l.a = r.c) + """ + ), + ) + + assertDataFrameEqual( + self.spark.table("l").where( + self.spark.table("r") + .where( + (sf.col("a").outer() == sf.col("c")) + & (sf.col("b").outer() < sf.col("d")) + ) + .exists() + | self.spark.table("r").where(sf.col("a").outer() == sf.col("c")).exists() + ), + self.spark.sql( + """ + select * from l where exists (select * from r where l.a = r.c and l.b < r.d) + or exists (select * from r where l.a = r.c) + """ + ), + ) + + def test_scalar_subquery_with_outer_reference_errors(self): + with self.tempView("l", "r"): + self.df1.createOrReplaceTempView("l") + self.df2.createOrReplaceTempView("r") + + with self.subTest("missing `outer()`"): + with self.assertRaises(AnalysisException) as pe: + self.spark.table("l").select( + "a", + ( + self.spark.table("r") + .where(sf.col("c") == sf.col("a")) + .select(sf.sum("d")) + .scalar() + ), + ).collect() + + self.check_error( + exception=pe.exception, + 
errorClass="UNRESOLVED_COLUMN.WITH_SUGGESTION", + messageParameters={"objectName": "`a`", "proposal": "`c`, `d`"}, + query_context_type=QueryContextType.DataFrame, + fragment="col", + ) + + with self.subTest("extra `outer()`"): + with self.assertRaises(AnalysisException) as pe: + self.spark.table("l").select( + "a", + ( + self.spark.table("r") + .where(sf.col("c").outer() == sf.col("a").outer()) + .select(sf.sum("d")) + .scalar() + ), + ).collect() + + self.check_error( + exception=pe.exception, + errorClass="UNRESOLVED_COLUMN.WITH_SUGGESTION", + messageParameters={"objectName": "`c`", "proposal": "`a`, `b`"}, + query_context_type=QueryContextType.DataFrame, + fragment="outer", + ) + + +class SubqueryTests(SubqueryTestsMixin, ReusedSQLTestCase): + pass + + +if __name__ == "__main__": + from pyspark.sql.tests.test_subquery import * # noqa: F401 + + try: + import xmlrunner # type: ignore + + testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/sql/api/src/main/scala/org/apache/spark/sql/Column.scala b/sql/api/src/main/scala/org/apache/spark/sql/Column.scala index 31ce44eca1684..8498ae04d9a2a 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/Column.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.parser.DataTypeParser import org.apache.spark.sql.catalyst.trees.CurrentOrigin.withOrigin import org.apache.spark.sql.expressions.Window import org.apache.spark.sql.functions.{lit, map} -import org.apache.spark.sql.internal.ColumnNode +import org.apache.spark.sql.internal.{ColumnNode, LazyOuterReference, UnresolvedAttribute} import org.apache.spark.sql.types._ import org.apache.spark.util.ArrayImplicits._ @@ -1382,6 +1382,21 @@ class Column(val node: ColumnNode) extends Logging { */ def over(): Column = over(Window.spec) + /** + * Marks this column reference as an outer reference for subqueries. + * + * @group subquery + * @since 4.0.0 + */ + def outer(): Column = withOrigin { + node match { + case attr: UnresolvedAttribute if !attr.isMetadataColumn => + Column(LazyOuterReference(attr.nameParts, attr.planId)) + case _ => + throw new IllegalArgumentException( + "Only unresolved attributes can be used as outer references") + } + } } /** diff --git a/sql/api/src/main/scala/org/apache/spark/sql/api/Dataset.scala b/sql/api/src/main/scala/org/apache/spark/sql/api/Dataset.scala index 416f89ba6f09c..9d41998f11dc6 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/api/Dataset.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/api/Dataset.scala @@ -1699,6 +1699,33 @@ abstract class Dataset[T] extends Serializable { */ def transpose(): Dataset[Row] + /** + * Return a `Column` object for a SCALAR Subquery containing exactly one row and one column. + * + * The `scalar()` method is useful for extracting a `Column` object that represents a scalar + * value from a DataFrame, especially when the DataFrame results from an aggregation or + * single-value computation. This returned `Column` can then be used directly in `select` + * clauses or as predicates in filters on the outer DataFrame, enabling dynamic data filtering + * and calculations based on scalar values. + * + * @group subquery + * @since 4.0.0 + */ + def scalar(): Column + + /** + * Return a `Column` object for an EXISTS Subquery. 
+ * + * The `exists` method provides a way to create a boolean column that checks for the presence of + * related records in a subquery. When applied within a `DataFrame`, this method allows you to + * filter rows based on whether matching records exist in the related dataset. The resulting + * `Column` object can be used directly in filtering conditions or as a computed column. + * + * @group subquery + * @since 4.0.0 + */ + def exists(): Column + /** * Define (named) metrics to observe on the Dataset. This method returns an 'observed' Dataset * that returns the same result as the input, with the following guarantees:
  • It will diff --git a/sql/api/src/main/scala/org/apache/spark/sql/internal/columnNodes.scala b/sql/api/src/main/scala/org/apache/spark/sql/internal/columnNodes.scala index 979baf12be614..e3cc320a8b00f 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/internal/columnNodes.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/internal/columnNodes.scala @@ -167,6 +167,24 @@ private[sql] object UnresolvedAttribute { apply(unparsedIdentifier, None, false, CurrentOrigin.get) } +/** + * Reference to an attribute in the outer context, used for Subqueries. + * + * @param nameParts + * name of the attribute. + * @param planId + * id of the plan (Dataframe) that produces the attribute. + */ +private[sql] case class LazyOuterReference( + nameParts: Seq[String], + planId: Option[Long] = None, + override val origin: Origin = CurrentOrigin.get) + extends ColumnNode { + override private[internal] def normalize(): LazyOuterReference = + copy(planId = None, origin = NO_ORIGIN) + override def sql: String = nameParts.map(n => if (n.contains(".")) s"`$n`" else n).mkString(".") +} + /** * Reference to all columns in a namespace (global, a Dataframe, or a nested struct). * diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index b94bd31eb3fa1..e87fe447584a4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -2479,12 +2479,23 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor * can resolve outer references. * * Outer references of the subquery are updated as children of Subquery expression. + * + * If hasExplicitOuterRefs is true, the subquery should have an explicit outer reference, + * instead of common `UnresolvedAttribute`s. In this case, tries to resolve inner and outer + * references separately. 
*/ private def resolveSubQuery( e: SubqueryExpression, - outer: LogicalPlan)( + outer: LogicalPlan, + hasExplicitOuterRefs: Boolean = false)( f: (LogicalPlan, Seq[Expression]) => SubqueryExpression): SubqueryExpression = { - val newSubqueryPlan = AnalysisContext.withOuterPlan(outer) { + val newSubqueryPlan = if (hasExplicitOuterRefs) { + executeSameContext(e.plan).transformAllExpressionsWithPruning( + _.containsPattern(UNRESOLVED_OUTER_REFERENCE)) { + case u: UnresolvedOuterReference => + resolveOuterReference(u.nameParts, outer).getOrElse(u) + } + } else AnalysisContext.withOuterPlan(outer) { executeSameContext(e.plan) } @@ -2509,10 +2520,11 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor */ private def resolveSubQueries(plan: LogicalPlan, outer: LogicalPlan): LogicalPlan = { plan.transformAllExpressionsWithPruning(_.containsPattern(PLAN_EXPRESSION), ruleId) { - case s @ ScalarSubquery(sub, _, exprId, _, _, _, _) if !sub.resolved => - resolveSubQuery(s, outer)(ScalarSubquery(_, _, exprId)) - case e @ Exists(sub, _, exprId, _, _) if !sub.resolved => - resolveSubQuery(e, outer)(Exists(_, _, exprId)) + case s @ ScalarSubquery(sub, _, exprId, _, _, _, _, hasExplicitOuterRefs) + if !sub.resolved => + resolveSubQuery(s, outer, hasExplicitOuterRefs)(ScalarSubquery(_, _, exprId)) + case e @ Exists(sub, _, exprId, _, _, hasExplicitOuterRefs) if !sub.resolved => + resolveSubQuery(e, outer, hasExplicitOuterRefs)(Exists(_, _, exprId)) case InSubquery(values, l @ ListQuery(_, _, exprId, _, _, _)) if values.forall(_.resolved) && !l.resolved => val expr = resolveSubQuery(l, outer)((plan, exprs) => { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 724014273fed4..c7e5fa9f2b6c6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -457,6 +457,11 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB errorClass = "UNBOUND_SQL_PARAMETER", messageParameters = Map("name" -> p.name)) + case l: LazyAnalysisExpression => + l.failAnalysis( + errorClass = "UNANALYZABLE_EXPRESSION", + messageParameters = Map("expr" -> toSQLExpr(l))) + case _ => }) @@ -1062,6 +1067,20 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB case _ => } + def checkUnresolvedOuterReference(p: LogicalPlan, expr: SubqueryExpression): Unit = { + expr.plan.foreachUp(_.expressions.foreach(_.foreachUp { + case o: UnresolvedOuterReference => + val cols = p.inputSet.toSeq.map(attr => toSQLId(attr.name)).mkString(", ") + o.failAnalysis( + errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", + messageParameters = Map("objectName" -> toSQLId(o.name), "proposal" -> cols)) + case _ => + })) + } + + // Check if there is unresolved outer attribute in the subquery plan. + checkUnresolvedOuterReference(plan, expr) + // Validate the subquery plan. checkAnalysis0(expr.plan) @@ -1069,7 +1088,7 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB checkOuterReference(plan, expr) expr match { - case ScalarSubquery(query, outerAttrs, _, _, _, _, _) => + case ScalarSubquery(query, outerAttrs, _, _, _, _, _, _) => // Scalar subquery must return one column as output. 
if (query.output.size != 1) { throw QueryCompilationErrors.subqueryReturnMoreThanOneColumn(query.output.size, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala index d9c723aecbe8e..e869cb281ce05 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala @@ -221,34 +221,35 @@ trait ColumnResolutionHelper extends Logging with DataTypeErrorsBase { val outerPlan = AnalysisContext.get.outerPlan if (outerPlan.isEmpty) return e - def resolve(nameParts: Seq[String]): Option[Expression] = try { - outerPlan.get match { - // Subqueries in UnresolvedHaving can host grouping expressions and aggregate functions. - // We should resolve columns with `agg.output` and the rule `ResolveAggregateFunctions` will - // push them down to Aggregate later. This is similar to what we do in `resolveColumns`. - case u @ UnresolvedHaving(_, agg: Aggregate) => - agg.resolveChildren(nameParts, conf.resolver) - .orElse(u.resolveChildren(nameParts, conf.resolver)) - .map(wrapOuterReference) - case other => - other.resolveChildren(nameParts, conf.resolver).map(wrapOuterReference) - } - } catch { - case ae: AnalysisException => - logDebug(ae.getMessage) - None - } - e.transformWithPruning(_.containsAnyPattern(UNRESOLVED_ATTRIBUTE, TEMP_RESOLVED_COLUMN)) { case u: UnresolvedAttribute => - resolve(u.nameParts).getOrElse(u) + resolveOuterReference(u.nameParts, outerPlan.get).getOrElse(u) // Re-resolves `TempResolvedColumn` as outer references if it has tried to be resolved with // Aggregate but failed. case t: TempResolvedColumn if t.hasTried => - resolve(t.nameParts).getOrElse(t) + resolveOuterReference(t.nameParts, outerPlan.get).getOrElse(t) } } + protected def resolveOuterReference( + nameParts: Seq[String], outerPlan: LogicalPlan): Option[Expression] = try { + outerPlan match { + // Subqueries in UnresolvedHaving can host grouping expressions and aggregate functions. + // We should resolve columns with `agg.output` and the rule `ResolveAggregateFunctions` will + // push them down to Aggregate later. This is similar to what we do in `resolveColumns`. + case u @ UnresolvedHaving(_, agg: Aggregate) => + agg.resolveChildren(nameParts, conf.resolver) + .orElse(u.resolveChildren(nameParts, conf.resolver)) + .map(wrapOuterReference) + case other => + other.resolveChildren(nameParts, conf.resolver).map(wrapOuterReference) + } + } catch { + case ae: AnalysisException => + logDebug(ae.getMessage) + None + } + def lookupVariable(nameParts: Seq[String]): Option[VariableReference] = { // The temp variables live in `SYSTEM.SESSION`, and the name can be qualified or not. 
def maybeTempVariableName(nameParts: Seq[String]): Boolean = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index 389e939bd8e2f..40994f42e71d6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -1003,3 +1003,43 @@ case class UnresolvedTranspose( override protected def withNewChildInternal(newChild: LogicalPlan): UnresolvedTranspose = copy(child = newChild) } + +case class UnresolvedOuterReference( + nameParts: Seq[String]) + extends LeafExpression with NamedExpression with Unevaluable { + + def name: String = + nameParts.map(n => if (n.contains(".")) s"`$n`" else n).mkString(".") + + override def exprId: ExprId = throw new UnresolvedException("exprId") + override def dataType: DataType = throw new UnresolvedException("dataType") + override def nullable: Boolean = throw new UnresolvedException("nullable") + override def qualifier: Seq[String] = throw new UnresolvedException("qualifier") + override lazy val resolved = false + + override def toAttribute: Attribute = throw new UnresolvedException("toAttribute") + override def newInstance(): UnresolvedOuterReference = this + + final override val nodePatterns: Seq[TreePattern] = Seq(UNRESOLVED_OUTER_REFERENCE) +} + +case class LazyOuterReference( + nameParts: Seq[String]) + extends LeafExpression with NamedExpression with Unevaluable with LazyAnalysisExpression { + + def name: String = + nameParts.map(n => if (n.contains(".")) s"`$n`" else n).mkString(".") + + override def exprId: ExprId = throw new UnresolvedException("exprId") + override def dataType: DataType = throw new UnresolvedException("dataType") + override def nullable: Boolean = throw new UnresolvedException("nullable") + override def qualifier: Seq[String] = throw new UnresolvedException("qualifier") + + override def toAttribute: Attribute = throw new UnresolvedException("toAttribute") + override def newInstance(): NamedExpression = LazyOuterReference(nameParts) + + override def nodePatternsInternal(): Seq[TreePattern] = Seq(LAZY_OUTER_REFERENCE) + + override def prettyName: String = "outer" + override def sql: String = s"$prettyName($name)" +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index bb32e518ec39a..140335ef8bdd6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateFunction import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.trees.{BinaryLike, CurrentOrigin, LeafLike, QuaternaryLike, TernaryLike, TreeNode, UnaryLike} -import org.apache.spark.sql.catalyst.trees.TreePattern.{RUNTIME_REPLACEABLE, TreePattern} +import org.apache.spark.sql.catalyst.trees.TreePattern.{LAZY_ANALYSIS_EXPRESSION, RUNTIME_REPLACEABLE, TreePattern} import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.catalyst.util.truncatedString import org.apache.spark.sql.errors.{QueryErrorsBase, QueryExecutionErrors} @@ -404,6 +404,20 @@ 
trait Unevaluable extends Expression with FoldableUnevaluable { final override def foldable: Boolean = false } +/** + * An expression that cannot be analyzed. These expressions don't live analysis time or after + * and should not be evaluated during query planning and execution. + */ +trait LazyAnalysisExpression extends Expression { + final override lazy val resolved = false + + final override val nodePatterns: Seq[TreePattern] = + Seq(LAZY_ANALYSIS_EXPRESSION) ++ nodePatternsInternal() + + // Subclasses can override this function to provide more TreePatterns. + def nodePatternsInternal(): Seq[TreePattern] = Seq() +} + /** * An expression that gets replaced at runtime (currently by the optimizer) into a different * expression for evaluation. This is mainly used to provide compatibility with other databases. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala index 0c8253659dd56..bd6f65b61468d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala @@ -19,9 +19,11 @@ package org.apache.spark.sql.catalyst.expressions import scala.collection.mutable.ArrayBuffer +import org.apache.spark.sql.catalyst.analysis.{LazyOuterReference, UnresolvedOuterReference} import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.trees.TreePattern import org.apache.spark.sql.catalyst.trees.TreePattern._ import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.internal.SQLConf @@ -372,6 +374,13 @@ object SubExprUtils extends PredicateHelper { val nonEquivalentGroupByExprs = groupByExprs -- correlatedEquivalentExprs nonEquivalentGroupByExprs } + + def removeLazyOuterReferences(logicalPlan: LogicalPlan): LogicalPlan = { + logicalPlan.transformAllExpressionsWithPruning( + _.containsPattern(TreePattern.LAZY_OUTER_REFERENCE)) { + case or: LazyOuterReference => UnresolvedOuterReference(or.nameParts) + } + } } /** @@ -398,7 +407,8 @@ case class ScalarSubquery( joinCond: Seq[Expression] = Seq.empty, hint: Option[HintInfo] = None, mayHaveCountBug: Option[Boolean] = None, - needSingleJoin: Option[Boolean] = None) + needSingleJoin: Option[Boolean] = None, + hasExplicitOuterRefs: Boolean = false) extends SubqueryExpression(plan, outerAttrs, exprId, joinCond, hint) with Unevaluable { override def dataType: DataType = { if (!plan.schema.fields.nonEmpty) { @@ -567,7 +577,8 @@ case class Exists( outerAttrs: Seq[Expression] = Seq.empty, exprId: ExprId = NamedExpression.newExprId, joinCond: Seq[Expression] = Seq.empty, - hint: Option[HintInfo] = None) + hint: Option[HintInfo] = None, + hasExplicitOuterRefs: Boolean = false) extends SubqueryExpression(plan, outerAttrs, exprId, joinCond, hint) with Predicate with Unevaluable { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 9a2aa82c25d51..90d9bd5d5d88e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -346,7 +346,7 @@ abstract class Optimizer(catalogManager: 
CatalogManager) case d: DynamicPruningSubquery => d case s @ ScalarSubquery( PhysicalOperation(projections, predicates, a @ Aggregate(group, _, child, _)), - _, _, _, _, mayHaveCountBug, _) + _, _, _, _, mayHaveCountBug, _, _) if conf.getConf(SQLConf.DECORRELATE_SUBQUERY_PREVENT_CONSTANT_FOLDING_FOR_COUNT_BUG) && mayHaveCountBug.nonEmpty && mayHaveCountBug.get => // This is a subquery with an aggregate that may suffer from a COUNT bug. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 06fc366ce6bba..195aa7bbeec02 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -90,7 +90,7 @@ object ConstantFolding extends Rule[LogicalPlan] { } // Don't replace ScalarSubquery if its plan is an aggregate that may suffer from a COUNT bug. - case s @ ScalarSubquery(_, _, _, _, _, mayHaveCountBug, _) + case s @ ScalarSubquery(_, _, _, _, _, mayHaveCountBug, _, _) if conf.getConf(SQLConf.DECORRELATE_SUBQUERY_PREVENT_CONSTANT_FOLDING_FOR_COUNT_BUG) && mayHaveCountBug.nonEmpty && mayHaveCountBug.get => s diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala index 5a4e9f37c3951..8c82769dbf4a3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala @@ -131,12 +131,12 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper { // Filter the plan by applying left semi and left anti joins. 
withSubquery.foldLeft(newFilter) { - case (p, Exists(sub, _, _, conditions, subHint)) => + case (p, Exists(sub, _, _, conditions, subHint, _)) => val (joinCond, outerPlan) = rewriteExistentialExpr(conditions, p) val join = buildJoin(outerPlan, rewriteDomainJoinsIfPresent(outerPlan, sub, joinCond), LeftSemi, joinCond, subHint) Project(p.output, join) - case (p, Not(Exists(sub, _, _, conditions, subHint))) => + case (p, Not(Exists(sub, _, _, conditions, subHint, _))) => val (joinCond, outerPlan) = rewriteExistentialExpr(conditions, p) val join = buildJoin(outerPlan, rewriteDomainJoinsIfPresent(outerPlan, sub, joinCond), LeftAnti, joinCond, subHint) @@ -319,7 +319,7 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper { val introducedAttrs = ArrayBuffer.empty[Attribute] val newExprs = exprs.map { e => e.transformDownWithPruning(_.containsAnyPattern(EXISTS_SUBQUERY, IN_SUBQUERY)) { - case Exists(sub, _, _, conditions, subHint) => + case Exists(sub, _, _, conditions, subHint, _) => val exists = AttributeReference("exists", BooleanType, nullable = false)() val existenceJoin = ExistenceJoin(exists) val newCondition = conditions.reduceLeftOption(And) @@ -507,7 +507,7 @@ object PullupCorrelatedPredicates extends Rule[LogicalPlan] with PredicateHelper plan.transformExpressionsWithPruning(_.containsPattern(PLAN_EXPRESSION)) { case ScalarSubquery(sub, children, exprId, conditions, hint, - mayHaveCountBugOld, needSingleJoinOld) + mayHaveCountBugOld, needSingleJoinOld, _) if children.nonEmpty => def mayHaveCountBugAgg(a: Aggregate): Boolean = { @@ -560,7 +560,7 @@ object PullupCorrelatedPredicates extends Rule[LogicalPlan] with PredicateHelper } ScalarSubquery(newPlan, children, exprId, getJoinCondition(newCond, conditions), hint, Some(mayHaveCountBug), Some(needSingleJoin)) - case Exists(sub, children, exprId, conditions, hint) if children.nonEmpty => + case Exists(sub, children, exprId, conditions, hint, _) if children.nonEmpty => val (newPlan, newCond) = if (SQLConf.get.decorrelateInnerQueryEnabledForExistsIn) { decorrelate(sub, plan, handleCountBug = true) } else { @@ -818,7 +818,7 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] with AliasHelpe val subqueryAttrMapping = ArrayBuffer[(Attribute, Attribute)]() val newChild = subqueries.foldLeft(child) { case (currentChild, ScalarSubquery(sub, _, _, conditions, subHint, mayHaveCountBug, - needSingleJoin)) => + needSingleJoin, _)) => val query = DecorrelateInnerQuery.rewriteDomainJoins(currentChild, sub, conditions) val origOutput = query.output.head // The subquery appears on the right side of the join, hence add its hint to the right @@ -1064,7 +1064,8 @@ object OptimizeOneRowRelationSubquery extends Rule[LogicalPlan] { case p: LogicalPlan => p.transformExpressionsUpWithPruning( _.containsPattern(SCALAR_SUBQUERY)) { - case s @ ScalarSubquery(OneRowSubquery(p @ Project(_, _: OneRowRelation)), _, _, _, _, _, _) + case s @ ScalarSubquery( + OneRowSubquery(p @ Project(_, _: OneRowRelation)), _, _, _, _, _, _, _) if !hasCorrelatedSubquery(s.plan) && s.joinCond.isEmpty => assert(p.projectList.size == 1) stripOuterReferences(p.projectList).head diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala index 0f1c98b53e0b3..24b787054fb13 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala @@ -154,6 +154,7 @@ object TreePattern extends Enumeration { val UNRESOLVED_HINT: Value = Value val UNRESOLVED_WINDOW_EXPRESSION: Value = Value val UNRESOLVED_IDENTIFIER_WITH_CTE: Value = Value + val UNRESOLVED_OUTER_REFERENCE: Value = Value // Unresolved Plan patterns (Alphabetically ordered) val UNRESOLVED_FUNC: Value = Value @@ -168,4 +169,8 @@ object TreePattern extends Enumeration { // Execution Plan patterns (alphabetically ordered) val EXCHANGE: Value = Value + + // Lazy analysis expression patterns (alphabetically ordered) + val LAZY_ANALYSIS_EXPRESSION: Value = Value + val LAZY_OUTER_REFERENCE: Value = Value } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index d43274d761af3..500a4c7c4d9bc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -43,7 +43,7 @@ import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.catalog.HiveTableRelation import org.apache.spark.sql.catalyst.encoders._ import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{agnosticEncoderFor, ProductEncoder, StructEncoder} -import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.{ScalarSubquery => ScalarSubqueryExpr, _} import org.apache.spark.sql.catalyst.json.{JacksonGenerator, JSONOptions} import org.apache.spark.sql.catalyst.parser.{ParseException, ParserUtils} import org.apache.spark.sql.catalyst.plans._ @@ -62,7 +62,7 @@ import org.apache.spark.sql.execution.datasources.LogicalRelationWithTable import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, DataSourceV2ScanRelation, FileTable} import org.apache.spark.sql.execution.python.EvaluatePython import org.apache.spark.sql.execution.stat.StatFunctions -import org.apache.spark.sql.internal.{DataFrameWriterImpl, DataFrameWriterV2Impl, MergeIntoWriterImpl, SQLConf} +import org.apache.spark.sql.internal.{DataFrameWriterImpl, DataFrameWriterV2Impl, ExpressionColumnNode, MergeIntoWriterImpl, SQLConf} import org.apache.spark.sql.internal.TypedAggUtils.withInputType import org.apache.spark.sql.streaming.DataStreamWriter import org.apache.spark.sql.types._ @@ -95,9 +95,14 @@ private[sql] object Dataset { def ofRows(sparkSession: SparkSession, logicalPlan: LogicalPlan): DataFrame = sparkSession.withActive { val qe = sparkSession.sessionState.executePlan(logicalPlan) - qe.assertAnalyzed() - new Dataset[Row](qe, RowEncoder.encoderFor(qe.analyzed.schema)) - } + val encoder = if (qe.isLazyAnalysis) { + RowEncoder.encoderFor(new StructType()) + } else { + qe.assertAnalyzed() + RowEncoder.encoderFor(qe.analyzed.schema) + } + new Dataset[Row](qe, encoder) + } def ofRows( sparkSession: SparkSession, @@ -106,8 +111,13 @@ private[sql] object Dataset { sparkSession.withActive { val qe = new QueryExecution( sparkSession, logicalPlan, shuffleCleanupMode = shuffleCleanupMode) - qe.assertAnalyzed() - new Dataset[Row](qe, RowEncoder.encoderFor(qe.analyzed.schema)) + val encoder = if (qe.isLazyAnalysis) { + RowEncoder.encoderFor(new StructType()) + } else { + qe.assertAnalyzed() + RowEncoder.encoderFor(qe.analyzed.schema) + } + new Dataset[Row](qe, encoder) } /** A variant of ofRows that allows passing in a tracker so we can track query parsing time. 
*/ @@ -119,8 +129,13 @@ private[sql] object Dataset { : DataFrame = sparkSession.withActive { val qe = new QueryExecution( sparkSession, logicalPlan, tracker, shuffleCleanupMode = shuffleCleanupMode) - qe.assertAnalyzed() - new Dataset[Row](qe, RowEncoder.encoderFor(qe.analyzed.schema)) + val encoder = if (qe.isLazyAnalysis) { + RowEncoder.encoderFor(new StructType()) + } else { + qe.assertAnalyzed() + RowEncoder.encoderFor(qe.analyzed.schema) + } + new Dataset[Row](qe, encoder) } } @@ -230,7 +245,9 @@ class Dataset[T] private[sql]( // A globally unique id of this Dataset. private[sql] val id = Dataset.curId.getAndIncrement() - queryExecution.assertAnalyzed() + if (!queryExecution.isLazyAnalysis) { + queryExecution.assertAnalyzed() + } // Note for Spark contributors: if adding or updating any action in `Dataset`, please make sure // you wrap it with `withNewExecutionId` if this actions doesn't call other action. @@ -244,13 +261,17 @@ class Dataset[T] private[sql]( } @transient private[sql] val logicalPlan: LogicalPlan = { - val plan = queryExecution.commandExecuted - if (sparkSession.sessionState.conf.getConf(SQLConf.FAIL_AMBIGUOUS_SELF_JOIN_ENABLED)) { - val dsIds = plan.getTagValue(Dataset.DATASET_ID_TAG).getOrElse(new HashSet[Long]) - dsIds.add(id) - plan.setTagValue(Dataset.DATASET_ID_TAG, dsIds) + if (queryExecution.isLazyAnalysis) { + queryExecution.logical + } else { + val plan = queryExecution.commandExecuted + if (sparkSession.sessionState.conf.getConf(SQLConf.FAIL_AMBIGUOUS_SELF_JOIN_ENABLED)) { + val dsIds = plan.getTagValue(Dataset.DATASET_ID_TAG).getOrElse(new HashSet[Long]) + dsIds.add(id) + plan.setTagValue(Dataset.DATASET_ID_TAG, dsIds) + } + plan } - plan } /** @@ -982,6 +1003,20 @@ class Dataset[T] private[sql]( ) } + /** @inheritdoc */ + def scalar(): Column = { + Column(ExpressionColumnNode( + ScalarSubqueryExpr(SubExprUtils.removeLazyOuterReferences(logicalPlan), + hasExplicitOuterRefs = true))) + } + + /** @inheritdoc */ + def exists(): Column = { + Column(ExpressionColumnNode( + Exists(SubExprUtils.removeLazyOuterReferences(logicalPlan), + hasExplicitOuterRefs = true))) + } + /** @inheritdoc */ @scala.annotation.varargs def observe(name: String, expr: Column, exprs: Column*): Dataset[T] = withTypedPlan { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index aad905256061d..490184c93620a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -36,6 +36,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.ByteCodeStats import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.logical.{AppendData, Command, CommandResult, CreateTableAsSelect, LogicalPlan, OverwriteByExpression, OverwritePartitionsDynamic, ReplaceTableAsSelect, ReturnAnswer, Union} import org.apache.spark.sql.catalyst.rules.{PlanChangeLogger, Rule} +import org.apache.spark.sql.catalyst.trees.TreePattern.LAZY_ANALYSIS_EXPRESSION import org.apache.spark.sql.catalyst.util.StringUtils.PlanStringConcat import org.apache.spark.sql.catalyst.util.truncatedString import org.apache.spark.sql.execution.adaptive.{AdaptiveExecutionContext, InsertAdaptiveSparkPlan} @@ -68,6 +69,8 @@ class QueryExecution( // TODO: Move the planner an optimizer into here from SessionState. 
protected def planner = sparkSession.sessionState.planner + lazy val isLazyAnalysis: Boolean = logical.containsAnyPattern(LAZY_ANALYSIS_EXPRESSION) + def assertAnalyzed(): Unit = { try { analyzed diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveSubqueries.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveSubqueries.scala index 5f2638655c37c..35a815d83922d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveSubqueries.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveSubqueries.scala @@ -30,7 +30,7 @@ case class PlanAdaptiveSubqueries( def apply(plan: SparkPlan): SparkPlan = { plan.transformAllExpressionsWithPruning( _.containsAnyPattern(SCALAR_SUBQUERY, IN_SUBQUERY, DYNAMIC_PRUNING_SUBQUERY)) { - case expressions.ScalarSubquery(_, _, exprId, _, _, _, _) => + case expressions.ScalarSubquery(_, _, exprId, _, _, _, _, _) => val subquery = SubqueryExec.createForScalarSubquery( s"subquery#${exprId.id}", subqueryMap(exprId.id)) execution.ScalarSubquery(subquery, exprId) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/columnNodeSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/columnNodeSupport.scala index 476956e58e8e6..64eacba1c6bf3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/columnNodeSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/columnNodeSupport.scala @@ -88,6 +88,9 @@ private[sql] trait ColumnNodeToExpressionConverter extends (ColumnNode => Expres isDistinct = isDistinct, isInternal = isInternal) + case LazyOuterReference(nameParts, planId, _) => + convertLazyOuterReference(nameParts, planId) + case Alias(child, Seq(name), None, _) => expressions.Alias(apply(child), name)( nonInheritableMetadataKeys = Seq(Dataset.DATASET_ID_KEY, Dataset.COL_POS_KEY)) @@ -245,6 +248,16 @@ private[sql] trait ColumnNodeToExpressionConverter extends (ColumnNode => Expres } attribute } + + private def convertLazyOuterReference( + nameParts: Seq[String], + planId: Option[Long]): analysis.LazyOuterReference = { + val lazyOuterReference = analysis.LazyOuterReference(nameParts) + if (planId.isDefined) { + lazyOuterReference.setTagValue(LogicalPlan.PLAN_ID_TAG, planId.get) + } + lazyOuterReference + } } private[sql] object ColumnNodeToExpressionConverter extends ColumnNodeToExpressionConverter { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSubquerySuite.scala new file mode 100644 index 0000000000000..fd31efb3054b1 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSubquerySuite.scala @@ -0,0 +1,367 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.SparkRuntimeException +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.SharedSparkSession + +class DataFrameSubquerySuite extends QueryTest with SharedSparkSession { + import testImplicits._ + + setupTestData() + + val row = identity[(java.lang.Integer, java.lang.Double)](_) + + lazy val l = Seq( + row((1, 2.0)), + row((1, 2.0)), + row((2, 1.0)), + row((2, 1.0)), + row((3, 3.0)), + row((null, null)), + row((null, 5.0)), + row((6, null))).toDF("a", "b") + + lazy val r = Seq( + row((2, 3.0)), + row((2, 3.0)), + row((3, 2.0)), + row((4, 1.0)), + row((null, null)), + row((null, 5.0)), + row((6, null))).toDF("c", "d") + + protected override def beforeAll(): Unit = { + super.beforeAll() + l.createOrReplaceTempView("l") + r.createOrReplaceTempView("r") + } + + test("unanalyzable expression") { + val exception = intercept[AnalysisException] { + spark.range(1).select($"id" === $"id".outer()).schema + } + checkError( + exception, + condition = "UNANALYZABLE_EXPRESSION", + parameters = Map("expr" -> "\"outer(id)\""), + queryContext = + Array(ExpectedContext(fragment = "outer", callSitePattern = getCurrentClassCallSitePattern)) + ) + } + + test("simple uncorrelated scalar subquery") { + checkAnswer( + spark.range(1).select( + spark.range(1).select(lit(1)).scalar().as("b") + ), + sql("select (select 1 as b) as b") + ) + + checkAnswer( + spark.range(1).select( + spark.range(1).select(spark.range(1).select(lit(1)).scalar() + 1).scalar() + lit(1) + ), + sql("select (select (select 1) + 1) + 1") + ) + + // string type + checkAnswer( + spark.range(1).select( + spark.range(1).select(lit("s")).scalar().as("b") + ), + sql("select (select 's' as s) as b") + ) + } + + test("uncorrelated scalar subquery should return null if there is 0 rows") { + checkAnswer( + spark.range(1).select( + spark.range(1).select(lit("s")).limit(0).scalar().as("b") + ), + sql("select (select 's' as s limit 0) as b") + ) + } + + test("uncorrelated scalar subquery on a DataFrame generated query") { + withTempView("subqueryData") { + val df = Seq((1, "one"), (2, "two"), (3, "three")).toDF("key", "value") + df.createOrReplaceTempView("subqueryData") + + checkAnswer( + spark.range(1).select( + spark.table("subqueryData") + .select($"key").where($"key" > 2).orderBy($"key").limit(1).scalar() + lit(1) + ), + sql("select (select key from subqueryData where key > 2 order by key limit 1) + 1") + ) + + checkAnswer( + spark.range(1).select( + -spark.table("subqueryData").select(max($"key")).scalar() + ), + sql("select -(select max(key) from subqueryData)") + ) + + checkAnswer( + spark.range(1).select( + spark.table("subqueryData").select($"value").limit(0).scalar() + ), + sql("select (select value from subqueryData limit 0)") + ) + + checkAnswer( + spark.range(1).select( + spark.table("subqueryData") + .where( + $"key" === spark.table("subqueryData").select(max($"key")).scalar() - lit(1) + ).select( + min($"value") + ).scalar() + ), + sql("select (select min(value) from subqueryData" + + " where key = (select max(key) from subqueryData) - 1)") + ) + } + } + + test("SPARK-15677: Queries against local relations with scalar subquery in Select list") { + withTempView("t1", "t2") { + Seq((1, 1), (2, 2)).toDF("c1", "c2").createOrReplaceTempView("t1") + Seq((1, 1), (2, 2)).toDF("c1", "c2").createOrReplaceTempView("t2") + + checkAnswer( + 
spark.table("t1").select( + spark.range(1).select(lit(1).as("col")).scalar() + ), + sql("SELECT (select 1 as col) from t1") + ) + + checkAnswer( + spark.table("t1").select( + spark.table("t2").select(max($"c1")).scalar() + ), + sql("SELECT (select max(c1) from t2) from t1") + ) + + checkAnswer( + spark.table("t1").select( + lit(1) + spark.range(1).select(lit(1).as("col")).scalar() + ), + sql("SELECT 1 + (select 1 as col) from t1") + ) + + checkAnswer( + spark.table("t1").select( + $"c1", + spark.table("t2").select(max($"c1")).scalar() + $"c2" + ), + sql("SELECT c1, (select max(c1) from t2) + c2 from t1") + ) + + checkAnswer( + spark.table("t1").select( + $"c1", + spark.table("t2").where($"t1.c2".outer() === $"t2.c2").select(max($"c1")).scalar() + ), + sql("SELECT c1, (select max(c1) from t2 where t1.c2 = t2.c2) from t1") + ) + } + } + + test("EXISTS predicate subquery") { + checkAnswer( + spark.table("l").where( + spark.table("r").where($"a".outer() === $"c").exists() + ), + sql("select * from l where exists (select * from r where l.a = r.c)") + ) + + checkAnswer( + spark.table("l").where( + spark.table("r").where($"a".outer() === $"c").exists() && $"a" <= lit(2) + ), + sql("select * from l where exists (select * from r where l.a = r.c) and l.a <= 2") + ) + } + + test("NOT EXISTS predicate subquery") { + checkAnswer( + spark.table("l").where( + !spark.table("r").where($"a".outer() === $"c").exists() + ), + sql("select * from l where not exists (select * from r where l.a = r.c)") + ) + + checkAnswer( + spark.table("l").where( + !spark.table("r").where($"a".outer() === $"c" && $"b".outer() < $"d").exists() + ), + sql("select * from l where not exists (select * from r where l.a = r.c and l.b < r.d)") + ) + } + + test("EXISTS predicate subquery within OR") { + checkAnswer( + spark.table("l").where( + spark.table("r").where($"a".outer() === $"c").exists() || + spark.table("r").where($"a".outer() === $"c").exists() + ), + sql("select * from l where exists (select * from r where l.a = r.c)" + + " or exists (select * from r where l.a = r.c)") + ) + + checkAnswer( + spark.table("l").where( + !spark.table("r").where($"a".outer() === $"c" && $"b".outer() < $"d").exists() || + !spark.table("r").where($"a".outer() === $"c").exists() + ), + sql("select * from l where not exists (select * from r where l.a = r.c and l.b < r.d)" + + " or not exists (select * from r where l.a = r.c)") + ) + } + + test("correlated scalar subquery in where") { + checkAnswer( + spark.table("l").where( + $"b" < spark.table("r").where($"a".outer() === $"c").select(max($"d")).scalar() + ), + sql("select * from l where b < (select max(d) from r where a = c)") + ) + } + + test("correlated scalar subquery in select") { + checkAnswer( + spark.table("l").select( + $"a", + spark.table("l").where($"a" === $"a".outer()).select(sum($"b")).scalar().as("sum_b") + ), + sql("select a, (select sum(b) from l l2 where l2.a = l1.a) sum_b from l l1") + ) + } + + test("correlated scalar subquery in select (null safe)") { + checkAnswer( + spark.table("l").select( + $"a", + spark.table("l").where($"a" <=> $"a".outer()).select(sum($"b")).scalar().as("sum_b") + ), + sql("select a, (select sum(b) from l l2 where l2.a <=> l1.a) sum_b from l l1") + ) + } + + test("correlated scalar subquery in aggregate") { + checkAnswer( + spark.table("l").groupBy( + $"a", + spark.table("r").where($"a".outer() === $"c").select(sum($"d")).scalar().as("sum_d") + ).agg(Map.empty[String, String]), + sql("select a, (select sum(d) from r where a = c) sum_d from l l1 group by 
1, 2") + ) + } + + test("SPARK-34269: correlated subquery with view in aggregate's grouping expression") { + withTable("tr") { + withView("vr") { + r.write.saveAsTable("tr") + sql("create view vr as select * from tr") + checkAnswer( + spark.table("l").groupBy( + $"a", + spark.table("vr").where($"a".outer() === $"c").select(sum($"d")).scalar().as("sum_d") + ).agg(Map.empty[String, String]), + sql("select a, (select sum(d) from vr where a = c) sum_d from l l1 group by 1, 2") + ) + } + } + } + + test("non-aggregated correlated scalar subquery") { + val exception1 = intercept[SparkRuntimeException] { + spark.table("l").select( + $"a", + spark.table("l").where($"a" === $"a".outer()).select($"b").scalar().as("sum_b") + ).collect() + } + checkError( + exception1, + condition = "SCALAR_SUBQUERY_TOO_MANY_ROWS" + ) + } + + test("non-equal correlated scalar subquery") { + checkAnswer( + spark.table("l").select( + $"a", + spark.table("l").where($"a" < $"a".outer()).select(sum($"b")).scalar().as("sum_b") + ), + sql("select a, (select sum(b) from l l2 where l2.a < l1.a) sum_b from l l1") + ) + } + + test("disjunctive correlated scalar subquery") { + checkAnswer( + spark.table("l").where( + spark.table("r").where( + ($"a".outer() === $"c" && $"d" === 2.0) || + ($"a".outer() === $"c" && $"d" === 1.0) + ).select(count(lit(1))).scalar() > 0 + ).select($"a"), + sql(""" + |select a + |from l + |where (select count(*) + | from r + | where (a = c and d = 2.0) or (a = c and d = 1.0)) > 0 + """.stripMargin) + ) + } + + test("correlated scalar subquery with outer reference errors") { + // Missing `outer()` + val exception1 = intercept[AnalysisException] { + spark.table("l").select( + $"a", + spark.table("r").where($"c" === $"a").select(sum($"d")).scalar() + ).collect() + } + checkError( + exception1, + condition = "UNRESOLVED_COLUMN.WITH_SUGGESTION", + parameters = Map("objectName" -> "`a`", "proposal" -> "`c`, `d`"), + queryContext = + Array(ExpectedContext(fragment = "$", callSitePattern = getCurrentClassCallSitePattern)) + ) + + // Extra `outer()` + val exception2 = intercept[AnalysisException] { + spark.table("l").select( + $"a", + spark.table("r").where($"c".outer() === $"a".outer()).select(sum($"d")).scalar() + ).collect() + } + checkError( + exception2, + condition = "UNRESOLVED_COLUMN.WITH_SUGGESTION", + parameters = Map("objectName" -> "`c`", "proposal" -> "`a`, `b`"), + queryContext = + Array(ExpectedContext(fragment = "outer", callSitePattern = getCurrentClassCallSitePattern)) + ) + } +} From 4f95a7f4dd2c2d30be549d2a90e52e44046e0726 Mon Sep 17 00:00:00 2001 From: changgyoopark-db Date: Tue, 12 Nov 2024 14:00:05 +0900 Subject: [PATCH 05/39] [SPARK-50260][CONNECT] Refactor and optimize Spark Connect execution and session management ### What changes were proposed in this pull request? Code refactoring. - Replace int with a dedicated case class to represent the state of an execution thread. Minor optimization. - Remove unnecessary steps before actually removing expired executions and sessions. ### Why are the changes needed? Improve code readability. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Existing tests. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48792 from changgyoopark-db/SPARK-50260. 
Authored-by: changgyoopark-db Signed-off-by: Hyukjin Kwon --- .../execution/ExecuteThreadRunner.scala | 22 +++++++------ .../SparkConnectExecutionManager.scala | 32 ++++++------------- .../service/SparkConnectSessionManager.scala | 16 +++------- 3 files changed, 27 insertions(+), 43 deletions(-) diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala index 61be2bc4eb994..d27f390a23f95 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.connect.execution -import java.util.concurrent.atomic.AtomicInteger +import java.util.concurrent.atomic.AtomicReference import scala.jdk.CollectionConverters._ import scala.util.control.NonFatal @@ -41,7 +41,8 @@ import org.apache.spark.util.Utils private[connect] class ExecuteThreadRunner(executeHolder: ExecuteHolder) extends Logging { /** The thread state. */ - private val state: AtomicInteger = new AtomicInteger(ThreadState.notStarted) + private val state: AtomicReference[ThreadStateInfo] = new AtomicReference( + ThreadState.notStarted) // The newly created thread will inherit all InheritableThreadLocals used by Spark, // e.g. SparkContext.localProperties. If considering implementing a thread-pool, @@ -349,17 +350,20 @@ private[connect] class ExecuteThreadRunner(executeHolder: ExecuteHolder) extends private object ThreadState { /** The thread has not started: transition to interrupted or started. */ - val notStarted: Int = 0 + val notStarted: ThreadStateInfo = ThreadStateInfo(0) /** Execution was interrupted: terminal state. */ - val interrupted: Int = 1 + val interrupted: ThreadStateInfo = ThreadStateInfo(1) /** The thread has started: transition to startedInterrupted or completed. */ - val started: Int = 2 + val started: ThreadStateInfo = ThreadStateInfo(2) - /** The thread has started and execution was interrupted: transition to completed. */ - val startedInterrupted: Int = 3 + /** The thread was started and execution has been interrupted: transition to completed. */ + val startedInterrupted: ThreadStateInfo = ThreadStateInfo(3) - /** Execution was completed: terminal state. */ - val completed: Int = 4 + /** Execution has been completed: terminal state. */ + val completed: ThreadStateInfo = ThreadStateInfo(4) } + +/** Represents the state of an execution thread. 
*/ +case class ThreadStateInfo(val transitionState: Int) diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutionManager.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutionManager.scala index d9eb5438c3886..f750ca6db67a8 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutionManager.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutionManager.scala @@ -21,7 +21,6 @@ import java.util.UUID import java.util.concurrent.{ConcurrentHashMap, ConcurrentMap, Executors, ScheduledExecutorService, TimeUnit} import java.util.concurrent.atomic.{AtomicLong, AtomicReference} -import scala.collection.mutable import scala.concurrent.duration.FiniteDuration import scala.jdk.CollectionConverters._ import scala.util.control.NonFatal @@ -160,19 +159,14 @@ private[connect] class SparkConnectExecutionManager() extends Logging { } private[connect] def removeAllExecutionsForSession(key: SessionKey): Unit = { - var sessionExecutionHolders = mutable.ArrayBuffer[ExecuteHolder]() executions.forEach((_, executeHolder) => { if (executeHolder.sessionHolder.key == key) { - sessionExecutionHolders += executeHolder + val info = executeHolder.getExecuteInfo + logInfo( + log"Execution ${MDC(LogKeys.EXECUTE_INFO, info)} removed in removeSessionExecutions.") + removeExecuteHolder(executeHolder.key, abandoned = true) } }) - - sessionExecutionHolders.foreach { executeHolder => - val info = executeHolder.getExecuteInfo - logInfo( - log"Execution ${MDC(LogKeys.EXECUTE_INFO, info)} removed in removeSessionExecutions.") - removeExecuteHolder(executeHolder.key, abandoned = true) - } } /** Get info about abandoned execution, if there is one. */ @@ -252,30 +246,24 @@ private[connect] class SparkConnectExecutionManager() extends Logging { // Visible for testing. private[connect] def periodicMaintenance(timeout: Long): Unit = { + // Find any detached executions that expired and should be removed. logInfo("Started periodic run of SparkConnectExecutionManager maintenance.") - // Find any detached executions that expired and should be removed. - val toRemove = new mutable.ArrayBuffer[ExecuteHolder]() val nowMs = System.currentTimeMillis() - executions.forEach((_, executeHolder) => { executeHolder.lastAttachedRpcTimeMs match { case Some(detached) => if (detached + timeout <= nowMs) { - toRemove += executeHolder + val info = executeHolder.getExecuteInfo + logInfo( + log"Found execution ${MDC(LogKeys.EXECUTE_INFO, info)} that was abandoned " + + log"and expired and will be removed.") + removeExecuteHolder(executeHolder.key, abandoned = true) } case _ => // execution is active } }) - // .. and remove them. 
- toRemove.foreach { executeHolder => - val info = executeHolder.getExecuteInfo - logInfo( - log"Found execution ${MDC(LogKeys.EXECUTE_INFO, info)} that was abandoned " + - log"and expired and will be removed.") - removeExecuteHolder(executeHolder.key, abandoned = true) - } logInfo("Finished periodic run of SparkConnectExecutionManager maintenance.") } diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManager.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManager.scala index 4ca3a80bfb985..a306856efa33c 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManager.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManager.scala @@ -21,7 +21,6 @@ import java.util.UUID import java.util.concurrent.{ConcurrentHashMap, ConcurrentMap, Executors, ScheduledExecutorService, TimeUnit} import java.util.concurrent.atomic.AtomicReference -import scala.collection.mutable import scala.concurrent.duration.FiniteDuration import scala.jdk.CollectionConverters._ import scala.util.control.NonFatal @@ -226,9 +225,8 @@ class SparkConnectSessionManager extends Logging { private def periodicMaintenance( defaultInactiveTimeoutMs: Long, ignoreCustomTimeout: Boolean): Unit = { - logInfo("Started periodic run of SparkConnectSessionManager maintenance.") // Find any sessions that expired and should be removed. - val toRemove = new mutable.ArrayBuffer[SessionHolder]() + logInfo("Started periodic run of SparkConnectSessionManager maintenance.") def shouldExpire(info: SessionHolderInfo, nowMs: Long): Boolean = { val timeoutMs = if (info.customInactiveTimeoutMs.isDefined && !ignoreCustomTimeout) { @@ -242,15 +240,8 @@ class SparkConnectSessionManager extends Logging { val nowMs = System.currentTimeMillis() sessionStore.forEach((_, sessionHolder) => { - if (shouldExpire(sessionHolder.getSessionHolderInfo, nowMs)) { - toRemove += sessionHolder - } - }) - - // .. and remove them. - toRemove.foreach { sessionHolder => val info = sessionHolder.getSessionHolderInfo - if (shouldExpire(info, System.currentTimeMillis())) { + if (shouldExpire(info, nowMs)) { logInfo( log"Found session ${MDC(SESSION_HOLD_INFO, info)} that expired " + log"and will be closed.") @@ -261,7 +252,8 @@ class SparkConnectSessionManager extends Logging { case NonFatal(ex) => logWarning("Unexpected exception closing session", ex) } } - } + }) + logInfo("Finished periodic run of SparkConnectSessionManager maintenance.") } From 00bff28cd66dc4d68de836c4290291ebc1df572b Mon Sep 17 00:00:00 2001 From: panbingkun Date: Tue, 12 Nov 2024 19:27:49 +0800 Subject: [PATCH 06/39] [SPARK-50056][SQL] Codegen Support for ParseUrl (by Invoke & RuntimeReplaceable) ### What changes were proposed in this pull request? The pr aims to add `Codegen` Support for `parse_url`. ### Why are the changes needed? - improve codegen coverage. - simplified code. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Pass GA & Existed UT (eg: UrlFunctionsSuite#`*parse_url*`) ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48579 from panbingkun/SPARK-50056. 
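For illustration only (not part of this patch): user-visible results are unchanged; only the evaluation path differs, since `ParseUrl` is now rewritten via `RuntimeReplaceable`/`Invoke` instead of using `CodegenFallback`. The sketch assumes a running `SparkSession` named `spark`; the URL is just an example value.

```scala
// Hedged sketch: same answer as before this patch, but the projection can now
// participate in whole-stage codegen rather than interpreted evaluation.
val df = spark.sql(
  "SELECT parse_url('https://spark.apache.org/path?query=1', 'HOST') AS host")
df.show()
// Expected single row: spark.apache.org
// df.queryExecution.executedPlan can be inspected to check codegen coverage.
```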
Authored-by: panbingkun Signed-off-by: Wenchen Fan --- .../url/UrlExpressionEvalUtils.scala | 148 +++++++++++++++ .../catalyst/expressions/urlExpressions.scala | 178 ++++-------------- .../expressions/StringExpressionsSuite.scala | 5 +- .../function_parse_url.explain | 2 +- .../function_parse_url_with_key.explain | 2 +- .../function_try_parse_url.explain | 2 +- .../function_try_parse_url_with_key.explain | 2 +- 7 files changed, 194 insertions(+), 145 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/url/UrlExpressionEvalUtils.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/url/UrlExpressionEvalUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/url/UrlExpressionEvalUtils.scala new file mode 100644 index 0000000000000..1eaa25a6bf72c --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/url/UrlExpressionEvalUtils.scala @@ -0,0 +1,148 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.catalyst.expressions.url + +import java.net.{URI, URISyntaxException} +import java.util.regex.Pattern + +import org.apache.spark.sql.errors.QueryExecutionErrors +import org.apache.spark.unsafe.types.UTF8String + +case class ParseUrlEvaluator( + url: UTF8String, + extractPart: UTF8String, + pattern: UTF8String, + failOnError: Boolean) { + + import ParseUrlEvaluator._ + + private lazy val cachedUrl: URI = + if (url != null) getUrl(url, failOnError) else null + + private lazy val cachedExtractPartFunc: URI => String = + if (extractPart != null) getExtractPartFunc(extractPart) else null + + private lazy val cachedPattern: Pattern = + if (pattern != null) getPattern(pattern) else null + + private def extractValueFromQuery(query: UTF8String, pattern: Pattern): UTF8String = { + val m = pattern.matcher(query.toString) + if (m.find()) { + UTF8String.fromString(m.group(2)) + } else { + null + } + } + + private def extractFromUrl(url: URI, partToExtract: UTF8String): UTF8String = { + if (cachedExtractPartFunc ne null) { + UTF8String.fromString(cachedExtractPartFunc(url)) + } else { + UTF8String.fromString(getExtractPartFunc(partToExtract)(url)) + } + } + + private def parseUrlWithoutKey(url: UTF8String, partToExtract: UTF8String): UTF8String = { + if (cachedUrl ne null) { + extractFromUrl(cachedUrl, partToExtract) + } else { + val currentUrl = getUrl(url, failOnError) + if (currentUrl ne null) { + extractFromUrl(currentUrl, partToExtract) + } else { + null + } + } + } + + final def evaluate(url: UTF8String, path: UTF8String): Any = { + parseUrlWithoutKey(url, path) + } + + final def evaluate(url: UTF8String, path: UTF8String, key: UTF8String): Any = { + if (path != QUERY) return null + + val query = parseUrlWithoutKey(url, path) + if (query eq null) return null + + if (cachedPattern ne null) { + extractValueFromQuery(query, cachedPattern) + } else { + extractValueFromQuery(query, getPattern(key)) + } + } +} + +object ParseUrlEvaluator { + private val HOST = UTF8String.fromString("HOST") + private val PATH = UTF8String.fromString("PATH") + private val QUERY = UTF8String.fromString("QUERY") + private val REF = UTF8String.fromString("REF") + private val PROTOCOL = UTF8String.fromString("PROTOCOL") + private val FILE = UTF8String.fromString("FILE") + private val AUTHORITY = UTF8String.fromString("AUTHORITY") + private val USERINFO = UTF8String.fromString("USERINFO") + private val REGEXPREFIX = "(&|^)" + private val REGEXSUBFIX = "=([^&]*)" + + private def getPattern(key: UTF8String): Pattern = { + Pattern.compile(REGEXPREFIX + key.toString + REGEXSUBFIX) + } + + private def getUrl(url: UTF8String, failOnError: Boolean): URI = { + try { + new URI(url.toString) + } catch { + case e: URISyntaxException if failOnError => + throw QueryExecutionErrors.invalidUrlError(url, e) + case _: URISyntaxException => null + } + } + + private def getExtractPartFunc(partToExtract: UTF8String): URI => String = { + + // partToExtract match { + // case HOST => _.toURL().getHost + // case PATH => _.toURL().getPath + // case QUERY => _.toURL().getQuery + // case REF => _.toURL().getRef + // case PROTOCOL => _.toURL().getProtocol + // case FILE => _.toURL().getFile + // case AUTHORITY => _.toURL().getAuthority + // case USERINFO => _.toURL().getUserInfo + // case _ => (url: URI) => null + // } + + partToExtract match { + case HOST => _.getHost + case PATH => _.getRawPath + case QUERY => _.getRawQuery + case REF => _.getRawFragment + case PROTOCOL => _.getScheme + case FILE => + (url: URI) => + 
if (url.getRawQuery ne null) { + url.getRawPath + "?" + url.getRawQuery + } else { + url.getRawPath + } + case AUTHORITY => _.getRawAuthority + case USERINFO => _.getRawUserInfo + case _ => (_: URI) => null + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/urlExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/urlExpressions.scala index bf1a788554284..22dcd33937dfb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/urlExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/urlExpressions.scala @@ -17,20 +17,18 @@ package org.apache.spark.sql.catalyst.expressions -import java.net.{URI, URISyntaxException, URLDecoder, URLEncoder} +import java.net.{URLDecoder, URLEncoder} import java.nio.charset.StandardCharsets -import java.util.regex.Pattern -import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.Cast._ -import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback -import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke +import org.apache.spark.sql.catalyst.expressions.objects.{Invoke, StaticInvoke} +import org.apache.spark.sql.catalyst.expressions.url.ParseUrlEvaluator import org.apache.spark.sql.catalyst.trees.UnaryLike import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.types.StringTypeWithCollation -import org.apache.spark.sql.types.{AbstractDataType, BooleanType, DataType} +import org.apache.spark.sql.types.{AbstractDataType, BooleanType, DataType, ObjectType} import org.apache.spark.unsafe.types.UTF8String // scalastyle:off line.size.limit @@ -156,19 +154,6 @@ object UrlCodec { } } -object ParseUrl { - private val HOST = UTF8String.fromString("HOST") - private val PATH = UTF8String.fromString("PATH") - private val QUERY = UTF8String.fromString("QUERY") - private val REF = UTF8String.fromString("REF") - private val PROTOCOL = UTF8String.fromString("PROTOCOL") - private val FILE = UTF8String.fromString("FILE") - private val AUTHORITY = UTF8String.fromString("AUTHORITY") - private val USERINFO = UTF8String.fromString("USERINFO") - private val REGEXPREFIX = "(&|^)" - private val REGEXSUBFIX = "=([^&]*)" -} - // scalastyle:off line.size.limit @ExpressionDescription( usage = "_FUNC_(url, partToExtract[, key]) - This is a special version of `parse_url` that performs the same operation, but returns a NULL value instead of raising an error if the parsing cannot be performed.", @@ -215,8 +200,13 @@ case class TryParseUrl(params: Seq[Expression], replacement: Expression) """, since = "2.0.0", group = "url_funcs") -case class ParseUrl(children: Seq[Expression], failOnError: Boolean = SQLConf.get.ansiEnabled) - extends Expression with ExpectsInputTypes with CodegenFallback { +case class ParseUrl( + children: Seq[Expression], + failOnError: Boolean = SQLConf.get.ansiEnabled) + extends Expression + with ExpectsInputTypes + with RuntimeReplaceable { + def this(children: Seq[Expression]) = this(children, SQLConf.get.ansiEnabled) override def nullable: Boolean = true @@ -225,29 +215,6 @@ case class ParseUrl(children: Seq[Expression], failOnError: Boolean = SQLConf.ge override def dataType: DataType = SQLConf.get.defaultStringType override def prettyName: String = "parse_url" - // If the url is a constant, cache the 
URL object so that we don't need to convert url - // from UTF8String to String to URL for every row. - @transient private lazy val cachedUrl = children(0) match { - case Literal(url: UTF8String, _) if url ne null => getUrl(url) - case _ => null - } - - // If the key is a constant, cache the Pattern object so that we don't need to convert key - // from UTF8String to String to StringBuilder to String to Pattern for every row. - @transient private lazy val cachedPattern = children(2) match { - case Literal(key: UTF8String, _) if key ne null => getPattern(key) - case _ => null - } - - // If the partToExtract is a constant, cache the Extract part function so that we don't need - // to check the partToExtract for every row. - @transient private lazy val cachedExtractPartFunc = children(1) match { - case Literal(part: UTF8String, _) => getExtractPartFunc(part) - case _ => null - } - - import ParseUrl._ - override def checkInputDataTypes(): TypeCheckResult = { if (children.size > 3 || children.size < 2) { throw QueryCompilationErrors.wrongNumArgsError( @@ -258,108 +225,41 @@ case class ParseUrl(children: Seq[Expression], failOnError: Boolean = SQLConf.ge } } - private def getPattern(key: UTF8String): Pattern = { - Pattern.compile(REGEXPREFIX + key.toString + REGEXSUBFIX) - } - - private def getUrl(url: UTF8String): URI = { - try { - new URI(url.toString) - } catch { - case e: URISyntaxException if failOnError => - throw QueryExecutionErrors.invalidUrlError(url, e) - case _: URISyntaxException => null - } - } - - private def getExtractPartFunc(partToExtract: UTF8String): URI => String = { - - // partToExtract match { - // case HOST => _.toURL().getHost - // case PATH => _.toURL().getPath - // case QUERY => _.toURL().getQuery - // case REF => _.toURL().getRef - // case PROTOCOL => _.toURL().getProtocol - // case FILE => _.toURL().getFile - // case AUTHORITY => _.toURL().getAuthority - // case USERINFO => _.toURL().getUserInfo - // case _ => (url: URI) => null - // } - - partToExtract match { - case HOST => _.getHost - case PATH => _.getRawPath - case QUERY => _.getRawQuery - case REF => _.getRawFragment - case PROTOCOL => _.getScheme - case FILE => - (url: URI) => - if (url.getRawQuery ne null) { - url.getRawPath + "?" + url.getRawQuery - } else { - url.getRawPath - } - case AUTHORITY => _.getRawAuthority - case USERINFO => _.getRawUserInfo - case _ => (url: URI) => null - } - } + override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): ParseUrl = + copy(children = newChildren) - private def extractValueFromQuery(query: UTF8String, pattern: Pattern): UTF8String = { - val m = pattern.matcher(query.toString) - if (m.find()) { - UTF8String.fromString(m.group(2)) - } else { - null - } + // If the url is a constant, cache the URL object so that we don't need to convert url + // from UTF8String to String to URL for every row. + @transient private lazy val url = children.head match { + case Literal(url: UTF8String, _) if url ne null => url + case _ => null } - private def extractFromUrl(url: URI, partToExtract: UTF8String): UTF8String = { - if (cachedExtractPartFunc ne null) { - UTF8String.fromString(cachedExtractPartFunc.apply(url)) - } else { - UTF8String.fromString(getExtractPartFunc(partToExtract).apply(url)) - } + // If the partToExtract is a constant, cache the Extract part function so that we don't need + // to check the partToExtract for every row. 
+ @transient private lazy val extractPart = children(1) match { + case Literal(part: UTF8String, _) => part + case _ => null } - private def parseUrlWithoutKey(url: UTF8String, partToExtract: UTF8String): UTF8String = { - if (cachedUrl ne null) { - extractFromUrl(cachedUrl, partToExtract) - } else { - val currentUrl = getUrl(url) - if (currentUrl ne null) { - extractFromUrl(currentUrl, partToExtract) - } else { - null - } + // If the key is a constant, cache the Pattern object so that we don't need to convert key + // from UTF8String to String to StringBuilder to String to Pattern for every row. + @transient private lazy val pattern = children.size match { + case 3 => children(2) match { + case Literal(key: UTF8String, _) if key ne null => key + case _ => null } + case _ => null } - override def eval(input: InternalRow): Any = { - val evaluated = children.map{e => e.eval(input).asInstanceOf[UTF8String]} - if (evaluated.contains(null)) return null - if (evaluated.size == 2) { - parseUrlWithoutKey(evaluated(0), evaluated(1)) - } else { - // 3-arg, i.e. QUERY with key - assert(evaluated.size == 3) - if (evaluated(1) != QUERY) { - return null - } - - val query = parseUrlWithoutKey(evaluated(0), evaluated(1)) - if (query eq null) { - return null - } - - if (cachedPattern ne null) { - extractValueFromQuery(query, cachedPattern) - } else { - extractValueFromQuery(query, getPattern(evaluated(2))) - } - } - } + @transient + private lazy val evaluator: ParseUrlEvaluator = ParseUrlEvaluator( + url, extractPart, pattern, failOnError) - override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): ParseUrl = - copy(children = newChildren) + override def replacement: Expression = Invoke( + Literal.create(evaluator, ObjectType(classOf[ParseUrlEvaluator])), + "evaluate", + dataType, + children, + children.map(_.dataType)) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala index aa7eafeed485a..1687d614cc5eb 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala @@ -1905,7 +1905,7 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { // exceptional cases intercept[java.util.regex.PatternSyntaxException] { evaluateWithoutCodegen(ParseUrl(Seq(Literal("http://spark.apache.org/path?"), - Literal("QUERY"), Literal("???")))) + Literal("QUERY"), Literal("???"))).replacement) } // arguments checking @@ -1956,7 +1956,8 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { "inputType" -> "\"INT\""))) // Test escaping of arguments - GenerateUnsafeProjection.generate(ParseUrl(Seq(Literal("\"quote"), Literal("\"quote"))) :: Nil) + GenerateUnsafeProjection.generate( + ParseUrl(Seq(Literal("\"quote"), Literal("\"quote"))).replacement :: Nil) } test("SPARK-33468: ParseUrl in ANSI mode should fail if input string is not a valid url") { diff --git a/sql/connect/common/src/test/resources/query-tests/explain-results/function_parse_url.explain b/sql/connect/common/src/test/resources/query-tests/explain-results/function_parse_url.explain index 3c874b5c8b6a4..1f9f3df800b8f 100644 --- a/sql/connect/common/src/test/resources/query-tests/explain-results/function_parse_url.explain +++ 
b/sql/connect/common/src/test/resources/query-tests/explain-results/function_parse_url.explain @@ -1,2 +1,2 @@ -Project [parse_url(g#0, g#0, false) AS parse_url(g, g)#0] +Project [invoke(ParseUrlEvaluator(null,null,null,false).evaluate(g#0, g#0)) AS parse_url(g, g)#0] +- LocalRelation , [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] diff --git a/sql/connect/common/src/test/resources/query-tests/explain-results/function_parse_url_with_key.explain b/sql/connect/common/src/test/resources/query-tests/explain-results/function_parse_url_with_key.explain index eba1c5c814fe3..900de9c243a83 100644 --- a/sql/connect/common/src/test/resources/query-tests/explain-results/function_parse_url_with_key.explain +++ b/sql/connect/common/src/test/resources/query-tests/explain-results/function_parse_url_with_key.explain @@ -1,2 +1,2 @@ -Project [parse_url(g#0, g#0, g#0, false) AS parse_url(g, g, g)#0] +Project [invoke(ParseUrlEvaluator(null,null,null,false).evaluate(g#0, g#0, g#0)) AS parse_url(g, g, g)#0] +- LocalRelation , [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] diff --git a/sql/connect/common/src/test/resources/query-tests/explain-results/function_try_parse_url.explain b/sql/connect/common/src/test/resources/query-tests/explain-results/function_try_parse_url.explain index 2fbf751ecf193..87bad58090a08 100644 --- a/sql/connect/common/src/test/resources/query-tests/explain-results/function_try_parse_url.explain +++ b/sql/connect/common/src/test/resources/query-tests/explain-results/function_try_parse_url.explain @@ -1,2 +1,2 @@ -Project [parse_url(g#0, g#0, false) AS try_parse_url(g, g)#0] +Project [invoke(ParseUrlEvaluator(null,null,null,false).evaluate(g#0, g#0)) AS try_parse_url(g, g)#0] +- LocalRelation , [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] diff --git a/sql/connect/common/src/test/resources/query-tests/explain-results/function_try_parse_url_with_key.explain b/sql/connect/common/src/test/resources/query-tests/explain-results/function_try_parse_url_with_key.explain index 74c4a4985acf2..aed35cfeb7009 100644 --- a/sql/connect/common/src/test/resources/query-tests/explain-results/function_try_parse_url_with_key.explain +++ b/sql/connect/common/src/test/resources/query-tests/explain-results/function_try_parse_url_with_key.explain @@ -1,2 +1,2 @@ -Project [parse_url(g#0, g#0, g#0, false) AS try_parse_url(g, g, g)#0] +Project [invoke(ParseUrlEvaluator(null,null,null,false).evaluate(g#0, g#0, g#0)) AS try_parse_url(g, g, g)#0] +- LocalRelation , [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] From e7ebda6994d31abd98f3a4863d80d9ed2ba1025b Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 12 Nov 2024 20:18:27 +0800 Subject: [PATCH 07/39] [SPARK-48273][SQL] Revert "[SPARK-48273][SQL] Fix late rewrite of PlanWithUnresolvedIdentifier" ### What changes were proposed in this pull request? This PR reverts https://github.com/apache/spark/pull/46580 (the tests are kept) because it's no longer needed after https://github.com/apache/spark/pull/47501 . The `PlanWithUnresolvedIdentifier` plan is now flatter, and all of its children are already resolved by the early batches. ### Why are the changes needed? Code cleanup. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Existing tests. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48786 from cloud-fan/ident. 
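For context, a rough sketch of the kind of query that goes through `ResolveIdentifierClause` (illustrative only, assuming a running SparkSession named `spark` and an existing table `t1`; not part of this PR):

```python
# IDENTIFIER() builds a relation name from a constant string expression, so the
# parser first produces a PlanWithUnresolvedIdentifier; ResolveIdentifierClause
# evaluates the expression and substitutes the real plan during analysis.
spark.sql("SELECT * FROM IDENTIFIER('t' || '1')").show()

# CACHE TABLE ... AS SELECT is the AstBuilder case touched above: the query is
# now passed as a child of the unresolved-identifier plan, so it is resolved by
# the normal batches instead of a late re-run of the early batches.
spark.sql("CACHE TABLE IDENTIFIER('t1' || '_cached') AS SELECT * FROM t1")
```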
Authored-by: Wenchen Fan Signed-off-by: Wenchen Fan --- .../apache/spark/sql/catalyst/analysis/Analyzer.scala | 9 +++------ .../catalyst/analysis/ResolveIdentifierClause.scala | 11 +++-------- .../apache/spark/sql/catalyst/parser/AstBuilder.scala | 4 ++-- .../spark/sql/catalyst/rules/RuleExecutor.scala | 2 +- 4 files changed, 9 insertions(+), 17 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index e87fe447584a4..2e82d7ad39c45 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -268,7 +268,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor TypeCoercion.typeCoercionRules } - private def earlyBatches: Seq[Batch] = Seq( + override def batches: Seq[Batch] = Seq( Batch("Substitution", fixedPoint, new SubstituteExecuteImmediate(catalogManager), // This rule optimizes `UpdateFields` expression chains so looks more like optimization rule. @@ -289,10 +289,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor Batch("Simple Sanity Check", Once, LookupFunctions), Batch("Keep Legacy Outputs", Once, - KeepLegacyOutputs) - ) - - override def batches: Seq[Batch] = earlyBatches ++ Seq( + KeepLegacyOutputs), Batch("Resolution", fixedPoint, new ResolveCatalogs(catalogManager) :: ResolveInsertInto :: @@ -339,7 +336,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor ResolveTimeZone :: ResolveRandomSeed :: ResolveBinaryArithmetic :: - new ResolveIdentifierClause(earlyBatches) :: + ResolveIdentifierClause :: ResolveUnion :: ResolveRowLevelCommandAssignments :: MoveParameterizedQueriesDown :: diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveIdentifierClause.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveIdentifierClause.scala index 0620f37fa0db9..0e1e71a658c8b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveIdentifierClause.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveIdentifierClause.scala @@ -20,24 +20,19 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.expressions.{AliasHelper, EvalHelper, Expression} import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.catalyst.plans.logical.{CTERelationRef, LogicalPlan, SubqueryAlias} -import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor} +import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.trees.TreePattern.{UNRESOLVED_IDENTIFIER, UNRESOLVED_IDENTIFIER_WITH_CTE} import org.apache.spark.sql.types.StringType /** * Resolves the identifier expressions and builds the original plans/expressions. 
*/ -class ResolveIdentifierClause(earlyBatches: Seq[RuleExecutor[LogicalPlan]#Batch]) - extends Rule[LogicalPlan] with AliasHelper with EvalHelper { - - private val executor = new RuleExecutor[LogicalPlan] { - override def batches: Seq[Batch] = earlyBatches.asInstanceOf[Seq[Batch]] - } +object ResolveIdentifierClause extends Rule[LogicalPlan] with AliasHelper with EvalHelper { override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUpWithPruning( _.containsAnyPattern(UNRESOLVED_IDENTIFIER, UNRESOLVED_IDENTIFIER_WITH_CTE)) { case p: PlanWithUnresolvedIdentifier if p.identifierExpr.resolved && p.childrenResolved => - executor.execute(p.planBuilder.apply(evalIdentifierExpr(p.identifierExpr), p.children)) + p.planBuilder.apply(evalIdentifierExpr(p.identifierExpr), p.children) case u @ UnresolvedWithCTERelations(p, cteRelations) => this.apply(p) match { case u @ UnresolvedRelation(Seq(table), _, _) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index caeb78d20e6a8..044e945d16ad1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -5230,7 +5230,7 @@ class AstBuilder extends DataTypeAstBuilder import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ val query = Option(ctx.query).map(plan) - withIdentClause(ctx.identifierReference, ident => { + withIdentClause(ctx.identifierReference, query.toSeq, (ident, children) => { if (query.isDefined && ident.length > 1) { val catalogAndNamespace = ident.init throw QueryParsingErrors.addCatalogInCacheTableAsSelectNotAllowedError( @@ -5246,7 +5246,7 @@ class AstBuilder extends DataTypeAstBuilder // alongside the text. // The same rule can be found in CREATE VIEW builder. checkInvalidParameter(query.get, "the query of CACHE TABLE") - CacheTableAsSelect(ident.head, query.get, source(ctx.query()), isLazy, options) + CacheTableAsSelect(ident.head, children.head, source(ctx.query()), isLazy, options) } else { CacheTable( createUnresolvedRelation(ctx.identifierReference, ident, None, writePrivileges = Nil), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala index 935233d5c85d6..256e1440122d8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala @@ -147,7 +147,7 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging { override val maxIterationsSetting: String = null) extends Strategy /** A batch of rules. */ - protected[catalyst] case class Batch(name: String, strategy: Strategy, rules: Rule[TreeType]*) + protected case class Batch(name: String, strategy: Strategy, rules: Rule[TreeType]*) /** Defines a sequence of rule batches, to be overridden by the implementation. */ protected def batches: Seq[Batch] From 916bf533a3ebf0d6569e502295f1d073ca2990c2 Mon Sep 17 00:00:00 2001 From: Fokko Driesprong Date: Tue, 12 Nov 2024 10:17:51 -0800 Subject: [PATCH 08/39] [SPARK-50259][BUILD] Update Parquet to 1.14.4 ### What changes were proposed in this pull request? Bumping Apache Parquet to 1.14.4 because of a critical bug when writing a dictionary larger than 8kb. 
For a full overview of bugfixes, see: https://github.com/apache/parquet-java/releases/tag/apache-parquet-1.14.4 ### Why are the changes needed? A serious issue was discovered in the 1.14.x line: https://github.com/apache/parquet-java/releases/tag/apache-parquet-1.14.4-rc2 ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Existing unit tests. ### Was this patch authored or co-authored using generative AI tooling? No Closes #48790 from Fokko/fd-bump-parquet-java. Lead-authored-by: Fokko Driesprong Co-authored-by: Fokko Signed-off-by: Dongjoon Hyun --- dev/deps/spark-deps-hadoop-3-hive-2.3 | 12 ++++++------ pom.xml | 2 +- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/dev/deps/spark-deps-hadoop-3-hive-2.3 b/dev/deps/spark-deps-hadoop-3-hive-2.3 index 048fdb9c219f5..342af10de7083 100644 --- a/dev/deps/spark-deps-hadoop-3-hive-2.3 +++ b/dev/deps/spark-deps-hadoop-3-hive-2.3 @@ -241,12 +241,12 @@ orc-shims/2.0.2//orc-shims-2.0.2.jar oro/2.0.8//oro-2.0.8.jar osgi-resource-locator/1.0.3//osgi-resource-locator-1.0.3.jar paranamer/2.8//paranamer-2.8.jar -parquet-column/1.14.3//parquet-column-1.14.3.jar -parquet-common/1.14.3//parquet-common-1.14.3.jar -parquet-encoding/1.14.3//parquet-encoding-1.14.3.jar -parquet-format-structures/1.14.3//parquet-format-structures-1.14.3.jar -parquet-hadoop/1.14.3//parquet-hadoop-1.14.3.jar -parquet-jackson/1.14.3//parquet-jackson-1.14.3.jar +parquet-column/1.14.4//parquet-column-1.14.4.jar +parquet-common/1.14.4//parquet-common-1.14.4.jar +parquet-encoding/1.14.4//parquet-encoding-1.14.4.jar +parquet-format-structures/1.14.4//parquet-format-structures-1.14.4.jar +parquet-hadoop/1.14.4//parquet-hadoop-1.14.4.jar +parquet-jackson/1.14.4//parquet-jackson-1.14.4.jar pickle/1.5//pickle-1.5.jar py4j/0.10.9.7//py4j-0.10.9.7.jar remotetea-oncrpc/1.1.2//remotetea-oncrpc-1.1.2.jar diff --git a/pom.xml b/pom.xml index a5d978e1e981e..a0d9e44ca1634 100644 --- a/pom.xml +++ b/pom.xml @@ -138,7 +138,7 @@ 3.8.1 10.16.1.1 - 1.14.3 + 1.14.4 2.0.2 shaded-protobuf 11.0.24 From 4f4eb22f5e9ecc3ae1ed49b5304ac092c993bbb9 Mon Sep 17 00:00:00 2001 From: Haejoon Lee Date: Wed, 13 Nov 2024 09:37:03 +0900 Subject: [PATCH 09/39] [SPARK-50234][PYTHON][SQL] Improve error message and test for transpose DataFrame API ### What changes were proposed in this pull request? This PR proposes to improve error message and test for transpose DataFrame API ### Why are the changes needed? To improve error message and negative testing ### Does this PR introduce _any_ user-facing change? No API changes, but user-facing error message is improved ### How was this patch tested? Updated the existing UTs ### Was this patch authored or co-authored using generative AI tooling? No Closes #48766 from itholic/SPARK-50234. 
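As a rough illustration of the user-visible difference (a sketch assuming a running SparkSession named `spark`; the DataFrame below is made up for the example):

```python
from pyspark.errors import AnalysisException

# The non-index columns "name" (STRING) and "age" (BIGINT) share no least
# common type, so transpose() fails; the error parameters now carry SQL-quoted
# type names produced by toSQLType instead of the bare dt.sql values.
df = spark.createDataFrame([(1, "x", 10), (2, "y", 20)], ["id", "name", "age"])
try:
    df.transpose().show()
except AnalysisException as e:
    print(e.getErrorClass())         # TRANSPOSE_NO_LEAST_COMMON_TYPE
    print(e.getMessageParameters())  # e.g. {'dt1': '"STRING"', 'dt2': '"BIGINT"'}
```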
Lead-authored-by: Haejoon Lee Co-authored-by: Haejoon Lee Signed-off-by: Haejoon Lee --- python/pyspark/sql/tests/test_dataframe.py | 2 +- .../sql/catalyst/analysis/ResolveTranspose.scala | 5 +++-- .../apache/spark/sql/DataFrameTransposeSuite.scala | 13 +++++++------ 3 files changed, 11 insertions(+), 9 deletions(-) diff --git a/python/pyspark/sql/tests/test_dataframe.py b/python/pyspark/sql/tests/test_dataframe.py index 2f53ca38743c1..cd6a57429cfa9 100644 --- a/python/pyspark/sql/tests/test_dataframe.py +++ b/python/pyspark/sql/tests/test_dataframe.py @@ -1041,7 +1041,7 @@ def test_transpose(self): self.check_error( exception=pe.exception, errorClass="TRANSPOSE_NO_LEAST_COMMON_TYPE", - messageParameters={"dt1": "STRING", "dt2": "BIGINT"}, + messageParameters={"dt1": '"STRING"', "dt2": '"BIGINT"'}, ) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTranspose.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTranspose.scala index d71237ca15ec3..df45360be8758 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTranspose.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTranspose.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, Attribute, A import org.apache.spark.sql.catalyst.plans.logical.{Filter, Limit, LogicalPlan, Project, Sort, Transpose} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.trees.TreePattern +import org.apache.spark.sql.errors.DataTypeErrors.toSQLType import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{AtomicType, DataType, StringType} import org.apache.spark.unsafe.types.UTF8String @@ -57,8 +58,8 @@ class ResolveTranspose(sparkSession: SparkSession) extends Rule[LogicalPlan] { throw new AnalysisException( errorClass = "TRANSPOSE_NO_LEAST_COMMON_TYPE", messageParameters = Map( - "dt1" -> dt1.sql, - "dt2" -> dt2.sql) + "dt1" -> toSQLType(dt1), + "dt2" -> toSQLType(dt2)) ) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTransposeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTransposeSuite.scala index e6e8b6d5e5b01..51de8553216c6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTransposeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTransposeSuite.scala @@ -58,12 +58,13 @@ class DataFrameTransposeSuite extends QueryTest with SharedSparkSession { assertResult(DoubleType)(transposedDf.schema("x").dataType) assertResult(DoubleType)(transposedDf.schema("y").dataType) - val exception = intercept[AnalysisException] { - person.transpose() - } - assert(exception.getMessage.contains( - "[TRANSPOSE_NO_LEAST_COMMON_TYPE] Transpose requires non-index columns " + - "to share a least common type")) + checkError( + exception = intercept[AnalysisException] { + person.transpose() + }, + condition = "TRANSPOSE_NO_LEAST_COMMON_TYPE", + parameters = Map("dt1" -> "\"STRING\"", "dt2" -> "\"INT\"") + ) } test("enforce ascending order based on index column values for transposed columns") { From 4002a5352d548c9718fd105290a68896f85c0f4d Mon Sep 17 00:00:00 2001 From: Harsh Motwani Date: Wed, 13 Nov 2024 09:52:33 +0900 Subject: [PATCH 10/39] [SPARK-50238][PYTHON] Add Variant Support in PySpark UDFs/UDTFs/UDAFs ### What changes were proposed in this pull request? This PR adds support for the Variant type in PySpark UDFs/UDTFs/UDAFs. 
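For example, a rough sketch of the new behavior (assuming a running SparkSession named `spark`; the raw value/metadata bytes mirror the test helpers in this PR and are illustrative only):

```python
from pyspark.sql.functions import col, parse_json, to_json, udf
from pyspark.sql.types import StringType, VariantType, VariantVal

df = spark.range(3).select(parse_json(col("id").cast("string")).alias("v"))

# Variant input: the Python UDF receives a VariantVal and can render it as JSON.
as_json = udf(lambda v: v.toJson(), StringType())
df.select(as_json(col("v")).alias("j")).show()

# Variant output: the UDF returns a VariantVal (here an int8 value encoded from
# raw bytes, pending a friendlier constructor API).
make_variant = udf(lambda i: VariantVal(bytes([12, i]), bytes([1, 0, 0])), VariantType())
spark.range(3).select(to_json(make_variant(col("id"))).alias("j")).show()
```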
Support is added in both modes - arrow and pickle - and support is also added in pandas UDFs. ### Why are the changes needed? After this change, users will be able to use the new Variant data type with UDFs, which is currently prohibited. ### Does this PR introduce _any_ user-facing change? Yes, users should now be able to use Variants with Python UDFs. ### How was this patch tested? Unit tests in all scenarios - arrow, pickle and pandas ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48770 from harshmotw-db/harsh-motwani_data/variant_udf_3. Authored-by: Harsh Motwani Signed-off-by: Hyukjin Kwon --- .../resources/error/error-conditions.json | 10 - python/pyspark/sql/pandas/serializers.py | 20 +- python/pyspark/sql/pandas/types.py | 30 +- .../sql/tests/pandas/test_pandas_udf.py | 41 ++- .../tests/pandas/test_pandas_udf_scalar.py | 297 ++++++++++++++++-- python/pyspark/sql/tests/test_types.py | 1 + python/pyspark/sql/tests/test_udf.py | 142 ++++++--- python/pyspark/sql/tests/test_udtf.py | 107 +++++++ python/pyspark/sql/types.py | 8 +- .../apache/spark/sql/util/ArrowUtils.scala | 28 +- .../sql/catalyst/expressions/PythonUDF.scala | 35 +-- .../sql/errors/QueryCompilationErrors.scala | 6 - .../sql/execution/arrow/ArrowWriter.scala | 2 +- .../sql/execution/python/EvaluatePython.scala | 11 +- .../sql/execution/python/PythonUDFSuite.scala | 56 ---- .../execution/python/PythonUDTFSuite.scala | 80 +---- 16 files changed, 595 insertions(+), 279 deletions(-) diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json index 987fc706f7c0b..7e7b84f2332b1 100644 --- a/common/utils/src/main/resources/error/error-conditions.json +++ b/common/utils/src/main/resources/error/error-conditions.json @@ -1045,16 +1045,6 @@ "The input of can't be type data." ] }, - "UNSUPPORTED_UDF_INPUT_TYPE" : { - "message" : [ - "UDFs do not support '' as an input data type." - ] - }, - "UNSUPPORTED_UDF_OUTPUT_TYPE" : { - "message" : [ - "UDFs do not support '' as an output data type." - ] - }, "VALUE_OUT_OF_RANGE" : { "message" : [ "The must be between (current value = )." diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py index 9ccef3dba6a4c..5bf07b87400fe 100644 --- a/python/pyspark/sql/pandas/serializers.py +++ b/python/pyspark/sql/pandas/serializers.py @@ -31,6 +31,7 @@ ) from pyspark.sql.pandas.types import ( from_arrow_type, + is_variant, to_arrow_type, _create_converter_from_pandas, _create_converter_to_pandas, @@ -420,7 +421,14 @@ def __init__( def arrow_to_pandas(self, arrow_column): import pyarrow.types as types - if self._df_for_struct and types.is_struct(arrow_column.type): + # If the arrow type is struct, return a pandas dataframe where the fields of the struct + # correspond to columns in the DataFrame. However, if the arrow struct is actually a + # Variant, which is an atomic type, treat it as a non-struct arrow type. + if ( + self._df_for_struct + and types.is_struct(arrow_column.type) + and not is_variant(arrow_column.type) + ): import pandas as pd series = [ @@ -505,7 +513,15 @@ def _create_batch(self, series): arrs = [] for s, t in series: - if self._struct_in_pandas == "dict" and t is not None and pa.types.is_struct(t): + # Variants are represented in arrow as structs with additional metadata (checked by + # is_variant). If the data type is Variant, return a VariantVal atomic type instead of + # a dict of two binary values. 
+ if ( + self._struct_in_pandas == "dict" + and t is not None + and pa.types.is_struct(t) + and not is_variant(t) + ): # A pandas UDF should return pd.DataFrame when the return type is a struct type. # If it returns a pd.Series, it should throw an error. if not isinstance(s, pd.DataFrame): diff --git a/python/pyspark/sql/pandas/types.py b/python/pyspark/sql/pandas/types.py index 57e46901013fe..648af21502864 100644 --- a/python/pyspark/sql/pandas/types.py +++ b/python/pyspark/sql/pandas/types.py @@ -171,7 +171,9 @@ def to_arrow_type( elif type(dt) == VariantType: fields = [ pa.field("value", pa.binary(), nullable=False), - pa.field("metadata", pa.binary(), nullable=False), + # The metadata field is tagged so we can identify that the arrow struct actually + # represents a variant. + pa.field("metadata", pa.binary(), nullable=False, metadata={b"variant": b"true"}), ] arrow_type = pa.struct(fields) else: @@ -221,6 +223,22 @@ def to_arrow_schema( return pa.schema(fields) +def is_variant(at: "pa.DataType") -> bool: + """Check if a PyArrow struct data type represents a variant""" + import pyarrow.types as types + + assert types.is_struct(at) + + return any( + ( + field.name == "metadata" + and b"variant" in field.metadata + and field.metadata[b"variant"] == b"true" + ) + for field in at + ) and any(field.name == "value" for field in at) + + def from_arrow_type(at: "pa.DataType", prefer_timestamp_ntz: bool = False) -> DataType: """Convert pyarrow type to Spark data type.""" import pyarrow.types as types @@ -280,6 +298,8 @@ def from_arrow_type(at: "pa.DataType", prefer_timestamp_ntz: bool = False) -> Da from_arrow_type(at.item_type, prefer_timestamp_ntz), ) elif types.is_struct(at): + if is_variant(at): + return VariantType() return StructType( [ StructField( @@ -1295,6 +1315,14 @@ def convert_udt(value: Any) -> Any: return convert_udt + elif isinstance(dt, VariantType): + + def convert_variant(variant: Any) -> Any: + assert isinstance(variant, VariantVal) + return {"value": variant.value, "metadata": variant.metadata} + + return convert_variant + return None conv = _converter(data_type) diff --git a/python/pyspark/sql/tests/pandas/test_pandas_udf.py b/python/pyspark/sql/tests/pandas/test_pandas_udf.py index 228fc30b497cc..4168af64a4d7c 100644 --- a/python/pyspark/sql/tests/pandas/test_pandas_udf.py +++ b/python/pyspark/sql/tests/pandas/test_pandas_udf.py @@ -20,7 +20,14 @@ from typing import cast from pyspark.sql.functions import udf, pandas_udf, PandasUDFType, assert_true, lit -from pyspark.sql.types import DoubleType, StructType, StructField, LongType, DayTimeIntervalType +from pyspark.sql.types import ( + DoubleType, + StructType, + StructField, + LongType, + DayTimeIntervalType, + VariantType, +) from pyspark.errors import ParseException, PythonException, PySparkTypeError from pyspark.util import PythonEvalType from pyspark.testing.sqlutils import ( @@ -42,33 +49,65 @@ def test_pandas_udf_basic(self): self.assertEqual(udf.returnType, DoubleType()) self.assertEqual(udf.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF) + udf = pandas_udf(lambda x: x, VariantType()) + self.assertEqual(udf.returnType, VariantType()) + self.assertEqual(udf.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF) + udf = pandas_udf(lambda x: x, DoubleType(), PandasUDFType.SCALAR) self.assertEqual(udf.returnType, DoubleType()) self.assertEqual(udf.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF) + udf = pandas_udf(lambda x: x, VariantType(), PandasUDFType.SCALAR) + self.assertEqual(udf.returnType, VariantType()) + 
self.assertEqual(udf.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF) + udf = pandas_udf( lambda x: x, StructType([StructField("v", DoubleType())]), PandasUDFType.GROUPED_MAP ) self.assertEqual(udf.returnType, StructType([StructField("v", DoubleType())])) self.assertEqual(udf.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF) + udf = pandas_udf( + lambda x: x, StructType([StructField("v", VariantType())]), PandasUDFType.GROUPED_MAP + ) + self.assertEqual(udf.returnType, StructType([StructField("v", VariantType())])) + self.assertEqual(udf.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF) + def test_pandas_udf_basic_with_return_type_string(self): udf = pandas_udf(lambda x: x, "double", PandasUDFType.SCALAR) self.assertEqual(udf.returnType, DoubleType()) self.assertEqual(udf.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF) + udf = pandas_udf(lambda x: x, "variant", PandasUDFType.SCALAR) + self.assertEqual(udf.returnType, VariantType()) + self.assertEqual(udf.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF) + udf = pandas_udf(lambda x: x, "v double", PandasUDFType.GROUPED_MAP) self.assertEqual(udf.returnType, StructType([StructField("v", DoubleType())])) self.assertEqual(udf.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF) + udf = pandas_udf(lambda x: x, "v variant", PandasUDFType.GROUPED_MAP) + self.assertEqual(udf.returnType, StructType([StructField("v", VariantType())])) + self.assertEqual(udf.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF) + udf = pandas_udf(lambda x: x, "v double", functionType=PandasUDFType.GROUPED_MAP) self.assertEqual(udf.returnType, StructType([StructField("v", DoubleType())])) self.assertEqual(udf.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF) + udf = pandas_udf(lambda x: x, "v variant", functionType=PandasUDFType.GROUPED_MAP) + self.assertEqual(udf.returnType, StructType([StructField("v", VariantType())])) + self.assertEqual(udf.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF) + udf = pandas_udf(lambda x: x, returnType="v double", functionType=PandasUDFType.GROUPED_MAP) self.assertEqual(udf.returnType, StructType([StructField("v", DoubleType())])) self.assertEqual(udf.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF) + udf = pandas_udf( + lambda x: x, returnType="v variant", functionType=PandasUDFType.GROUPED_MAP + ) + self.assertEqual(udf.returnType, StructType([StructField("v", VariantType())])) + self.assertEqual(udf.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF) + def test_pandas_udf_decorator(self): @pandas_udf(DoubleType()) def foo(x): diff --git a/python/pyspark/sql/tests/pandas/test_pandas_udf_scalar.py b/python/pyspark/sql/tests/pandas/test_pandas_udf_scalar.py index 80613e5f75bee..56a736a20b3af 100644 --- a/python/pyspark/sql/tests/pandas/test_pandas_udf_scalar.py +++ b/python/pyspark/sql/tests/pandas/test_pandas_udf_scalar.py @@ -27,7 +27,18 @@ from pyspark import TaskContext from pyspark.util import PythonEvalType from pyspark.sql import Column -from pyspark.sql.functions import array, col, expr, lit, sum, struct, udf, pandas_udf, PandasUDFType +from pyspark.sql.functions import ( + array, + col, + expr, + lit, + sum, + struct, + udf, + pandas_udf, + to_json, + PandasUDFType, +) from pyspark.sql.types import ( IntegerType, ByteType, @@ -752,46 +763,272 @@ def check_vectorized_udf_return_scalar(self): def test_udf_with_variant_input(self): df = self.spark.range(0, 10).selectExpr("parse_json(cast(id as string)) v") - from pyspark.sql.functions import col - scalar_f = pandas_udf(lambda u: str(u), StringType()) + scalar_f = 
pandas_udf(lambda u: u.apply(str), StringType(), PandasUDFType.SCALAR) iter_f = pandas_udf( - lambda it: map(lambda u: str(u), it), StringType(), PandasUDFType.SCALAR_ITER + lambda it: map(lambda u: u.apply(str), it), StringType(), PandasUDFType.SCALAR_ITER ) + expected = [Row(udf="{0}".format(i)) for i in range(10)] + for f in [scalar_f, iter_f]: - with self.assertRaises(AnalysisException) as ae: - df.select(f(col("v"))).collect() - - self.check_error( - exception=ae.exception, - errorClass="DATATYPE_MISMATCH.UNSUPPORTED_UDF_INPUT_TYPE", - messageParameters={ - "sqlExpr": '"(v)"', - "dataType": "VARIANT", - }, - ) + result = df.select(f(col("v")).alias("udf")).collect() + self.assertEqual(result, expected) def test_udf_with_variant_output(self): - # Corresponds to a JSON string of {"a": "b"}. - returned_variant = VariantVal(bytes([2, 1, 0, 0, 2, 5, 98]), bytes([1, 1, 0, 1, 97])) - scalar_f = pandas_udf(lambda x: returned_variant, VariantType()) + # Variants representing the int8 value i. + # TODO(SPARK-50284): Replace when an easy Python API to construct Variants is created. + scalar_f = pandas_udf( + lambda u: u.apply(lambda i: VariantVal(bytes([12, i]), bytes([1, 0, 0]))), VariantType() + ) iter_f = pandas_udf( - lambda it: map(lambda x: returned_variant, it), VariantType(), PandasUDFType.SCALAR_ITER + lambda it: map( + lambda u: u.apply(lambda i: VariantVal(bytes([12, i]), bytes([1, 0, 0]))), it + ), + VariantType(), + PandasUDFType.SCALAR_ITER, ) + expected = [Row(udf=i) for i in range(10)] + for f in [scalar_f, iter_f]: - with self.assertRaises(AnalysisException) as ae: - self.spark.range(0, 10).select(f()).collect() - - self.check_error( - exception=ae.exception, - errorClass="DATATYPE_MISMATCH.UNSUPPORTED_UDF_OUTPUT_TYPE", - messageParameters={ - "sqlExpr": '"()"', - "dataType": "VARIANT", - }, - ) + result = self.spark.range(10).select(f(col("id")).cast("int").alias("udf")).collect() + self.assertEqual(result, expected) + + def test_udf_with_nested_variant_input(self): + # struct + df = self.spark.range(0, 10).selectExpr( + "named_struct('v', parse_json(cast(id as string))) struct_of_v" + ) + scalar_f = pandas_udf(lambda u: u["v"].apply(str), StringType(), PandasUDFType.SCALAR) + iter_f = pandas_udf( + lambda it: map(lambda u: u["v"].apply(str), it), StringType(), PandasUDFType.SCALAR_ITER + ) + expected = [Row(udf=f"{i}") for i in range(10)] + for f in [scalar_f, iter_f]: + result = df.select(f(col("struct_of_v")).alias("udf")).collect() + self.assertEqual(result, expected) + + # array + df = self.spark.range(0, 10).selectExpr("array(parse_json(cast(id as string))) array_of_v") + scalar_f = pandas_udf(lambda u: u.apply(str), StringType(), PandasUDFType.SCALAR) + iter_f = pandas_udf( + lambda it: map(lambda u: u.apply(str), it), StringType(), PandasUDFType.SCALAR_ITER + ) + expected = [ + Row(udf=str([VariantVal(bytes([12, i]), bytes([1, 0, 0]))]).format(i)) + for i in range(10) + ] + for f in [scalar_f, iter_f]: + result = df.select(f(col("array_of_v")).alias("udf")).collect() + self.assertEqual(result, expected) + + # map + df = self.spark.range(0, 10).selectExpr("map('v', parse_json(cast(id as string))) map_of_v") + scalar_f = pandas_udf(lambda u: u.apply(str), StringType(), PandasUDFType.SCALAR) + iter_f = pandas_udf( + lambda it: map(lambda u: u.apply(str), it), StringType(), PandasUDFType.SCALAR_ITER + ) + expected = [ + Row(udf=str({"v": VariantVal(bytes([12, i]), bytes([1, 0, 0]))})) for i in range(10) + ] + for f in [scalar_f, iter_f]: + result = 
df.select(f(col("map_of_v")).alias("udf")).collect() + self.assertEqual(result, expected) + + def test_udf_with_variant_nested_output(self): + # struct + # Variants representing the int8 value i. + # TODO(SPARK-50284): Replace when an easy Python API to construct Variants is created. + scalar_f = pandas_udf( + lambda u: pd.DataFrame( + {"v": u.apply(lambda i: VariantVal(bytes([12, i]), bytes([1, 0, 0])))} + ), + StructType([StructField("v", VariantType(), True)]), + ) + iter_f = pandas_udf( + lambda it: map( + lambda u: pd.DataFrame( + {"v": u.apply(lambda i: VariantVal(bytes([12, i]), bytes([1, 0, 0])))} + ), + it, + ), + StructType([StructField("v", VariantType(), True)]), + PandasUDFType.SCALAR_ITER, + ) + expected = [Row(udf=f"{{{i}}}") for i in range(10)] + for f in [scalar_f, iter_f]: + result = self.spark.range(10).select(f(col("id")).cast("string").alias("udf")).collect() + self.assertEqual(result, expected) + + # array + # Variants representing the int8 value i. + # TODO(SPARK-50284): Replace when an easy Python API to construct Variants is created. + scalar_f = pandas_udf( + lambda u: u.apply(lambda i: [VariantVal(bytes([12, i]), bytes([1, 0, 0]))]), + ArrayType(VariantType()), + ) + iter_f = pandas_udf( + lambda it: map( + lambda u: u.apply(lambda i: [VariantVal(bytes([12, i]), bytes([1, 0, 0]))]), it + ), + ArrayType(VariantType()), + PandasUDFType.SCALAR_ITER, + ) + expected = [Row(udf=f"[{i}]") for i in range(10)] + for f in [scalar_f, iter_f]: + result = self.spark.range(10).select(f(col("id")).cast("string").alias("udf")).collect() + self.assertEqual(result, expected) + + # map + # Variants representing the int8 value i. + # TODO(SPARK-50284): Replace when an easy Python API to construct Variants is created. + scalar_f = pandas_udf( + lambda u: u.apply(lambda i: {"v": VariantVal(bytes([12, i]), bytes([1, 0, 0]))}), + MapType(StringType(), VariantType()), + ) + iter_f = pandas_udf( + lambda it: map( + lambda u: u.apply(lambda i: {"v": VariantVal(bytes([12, i]), bytes([1, 0, 0]))}), it + ), + MapType(StringType(), VariantType()), + PandasUDFType.SCALAR_ITER, + ) + expected = [Row(udf=f"{{v -> {i}}}") for i in range(10)] + for f in [scalar_f, iter_f]: + result = self.spark.range(10).select(f(col("id")).cast("string").alias("udf")).collect() + self.assertEqual(result, expected) + + def test_chained_udfs_with_variant(self): + # Variants representing the int8 value i. + # TODO(SPARK-50284): Replace when an easy Python API to construct Variants is created. + scalar_first = pandas_udf( + lambda u: u.apply(lambda i: VariantVal(bytes([12, i]), bytes([1, 0, 0]))), VariantType() + ) + iter_first = pandas_udf( + lambda it: map( + lambda u: u.apply(lambda i: VariantVal(bytes([12, i]), bytes([1, 0, 0]))), it + ), + VariantType(), + PandasUDFType.SCALAR_ITER, + ) + scalar_second = pandas_udf(lambda u: u.apply(str), StringType(), PandasUDFType.SCALAR) + iter_second = pandas_udf( + lambda it: map(lambda u: u.apply(str), it), StringType(), PandasUDFType.SCALAR_ITER + ) + + expected = [Row(udf="{0}".format(i)) for i in range(10)] + + for f in [scalar_first, iter_first]: + for s in [scalar_second, iter_second]: + result = self.spark.range(10).select(s(f(col("id"))).alias("udf")).collect() + self.assertEqual(result, expected) + + def test_chained_udfs_with_complex_variant(self): + # Variants representing the int8 value i. + # TODO(SPARK-50284): Replace when an easy Python API to construct Variants is created. 
+ scalar_first = pandas_udf( + lambda u: u.apply(lambda i: [VariantVal(bytes([12, i]), bytes([1, 0, 0]))]), + ArrayType(VariantType()), + ) + iter_first = pandas_udf( + lambda it: map( + lambda u: u.apply(lambda i: [VariantVal(bytes([12, i]), bytes([1, 0, 0]))]), it + ), + ArrayType(VariantType()), + PandasUDFType.SCALAR_ITER, + ) + scalar_second = pandas_udf( + lambda u: u.apply(lambda v: str(v[0])), StringType(), PandasUDFType.SCALAR + ) + iter_second = pandas_udf( + lambda it: map(lambda u: u.apply(lambda v: str(v[0])), it), + StringType(), + PandasUDFType.SCALAR_ITER, + ) + + expected = [Row(udf="{0}".format(i)) for i in range(10)] + + for f in [scalar_first, iter_first]: + for s in [scalar_second, iter_second]: + result = self.spark.range(10).select(s(f(col("id"))).alias("udf")).collect() + self.assertEqual(result, expected) + + def test_udafs_with_variant_input(self): + df = self.spark.range(0, 10).selectExpr("parse_json(cast(id as string)) v") + + @pandas_udf("double") + def f(u: pd.Series) -> float: + return u.apply(lambda v: len(str(v))).mean() + + expected = [Row(udf=1)] + result = df.select(f(col("v")).alias("udf")).collect() + self.assertEqual(result, expected) + + def test_udafs_with_complex_variant_input(self): + # struct + df = self.spark.range(0, 10).selectExpr("named_struct('v', parse_json(id::string)) s") + + @pandas_udf("double") + def f(u: pd.Series) -> float: + return u.apply(lambda s: len(str(s["v"]))).mean() + + expected = [Row(udf=1)] + result = df.select(f(col("s")).alias("udf")).collect() + self.assertEqual(result, expected) + + # array + df = self.spark.range(0, 10).selectExpr("array(parse_json(id::string)) a") + + @pandas_udf("double") + def f(u: pd.Series) -> float: + return u.apply(lambda s: len(str(s[0]))).mean() + 1 + + expected = [Row(udf=2)] + result = df.select(f(col("a")).alias("udf")).collect() + self.assertEqual(result, expected) + + # map + df = self.spark.range(0, 10).selectExpr("map('v', parse_json(id::string)) m") + + @pandas_udf("double") + def f(u: pd.Series) -> float: + return u.apply(lambda s: len(str(s["v"]))).mean() + 2 + + expected = [Row(udf=3)] + result = df.select(f(col("m")).alias("udf")).collect() + self.assertEqual(result, expected) + + def test_udafs_with_variant_output(self): + @pandas_udf("variant") + def f(u: pd.Series) -> VariantVal: + return VariantVal(bytes([12, int(u.mean())]), bytes([1, 0, 0])) + + result = self.spark.range(0, 10).select(to_json(f(col("id"))).alias("udf")).collect() + expected = [Row(udf="4")] + self.assertEqual(result, expected) + + def test_udafs_with_complex_variant_output(self): + # struct is not support as the return type for PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF + # UDFs yet. 
+ + # array + @pandas_udf("array") + def f(u: pd.Series) -> list: + return [VariantVal(bytes([12, int(u.mean())]), bytes([1, 0, 0]))] + + result = self.spark.range(0, 10).select(to_json(f(col("id"))).alias("udf")).collect() + expected = [Row(udf="[4]")] + self.assertEqual(result, expected) + + # map + @pandas_udf("map") + def f(u: pd.Series) -> dict: + return {"v": VariantVal(bytes([12, int(u.mean())]), bytes([1, 0, 0]))} + + result = self.spark.range(0, 10).select(to_json(f(col("id"))).alias("udf")).collect() + expected = [Row(udf='{"v":4}')] + self.assertEqual(result, expected) def test_vectorized_udf_decorator(self): df = self.spark.range(10) diff --git a/python/pyspark/sql/tests/test_types.py b/python/pyspark/sql/tests/test_types.py index 3697ea2d07869..9688ed4923737 100644 --- a/python/pyspark/sql/tests/test_types.py +++ b/python/pyspark/sql/tests/test_types.py @@ -2283,6 +2283,7 @@ def schema_from_udf(ddl): ("struct<>", True), ("struct>", True), ("", True), + ("a: int, b: variant", True), ("", False), ("randomstring", False), ("struct", False), diff --git a/python/pyspark/sql/tests/test_udf.py b/python/pyspark/sql/tests/test_udf.py index 879329bd80c0b..78aa2546128a1 100644 --- a/python/pyspark/sql/tests/test_udf.py +++ b/python/pyspark/sql/tests/test_udf.py @@ -335,67 +335,125 @@ def test_udf_with_filter_function(self): def test_udf_with_variant_input(self): df = self.spark.range(0, 10).selectExpr("parse_json(cast(id as string)) v") - u = udf(lambda u: str(u), StringType()) - with self.assertRaises(AnalysisException) as ae: - df.select(u(col("v"))).collect() - - self.check_error( - exception=ae.exception, - errorClass="DATATYPE_MISMATCH.UNSUPPORTED_UDF_INPUT_TYPE", - messageParameters={"sqlExpr": '"(v)"', "dataType": "VARIANT"}, - ) + expected = [Row(udf="{0}".format(i)) for i in range(10)] + result = df.select(u(col("v")).alias("udf")).collect() + self.assertEqual(result, expected) def test_udf_with_complex_variant_input(self): + # struct df = self.spark.range(0, 10).selectExpr( "named_struct('v', parse_json(cast(id as string))) struct_of_v" ) - - u = udf(lambda u: str(u), StringType()) - - with self.assertRaises(AnalysisException) as ae: - df.select(u(col("struct_of_v"))).collect() - - self.check_error( - exception=ae.exception, - errorClass="DATATYPE_MISMATCH.UNSUPPORTED_UDF_INPUT_TYPE", - messageParameters={ - "sqlExpr": '"(struct_of_v)"', - "dataType": "STRUCT", - }, - ) + u = udf(lambda u: str(u["v"]), StringType()) + result = df.select(u(col("struct_of_v"))).collect() + expected = [Row(udf="{0}".format(i)) for i in range(10)] + self.assertEqual(result, expected) + + # array + df = self.spark.range(0, 10).selectExpr("array(parse_json(cast(id as string))) array_of_v") + u = udf(lambda u: str(u[0]), StringType()) + result = df.select(u(col("array_of_v"))).collect() + expected = [Row(udf="{0}".format(i)) for i in range(10)] + self.assertEqual(result, expected) + + # map + df = self.spark.range(0, 10).selectExpr("map('v', parse_json(cast(id as string))) map_of_v") + u = udf(lambda u: str(u["v"]), StringType()) + result = df.select(u(col("map_of_v"))).collect() + expected = [Row(udf="{0}".format(i)) for i in range(10)] + self.assertEqual(result, expected) def test_udf_with_variant_output(self): - # The variant value returned corresponds to a JSON string of {"a": "b"}. + # The variant value returned corresponds to a JSON string of {"a": ""}. + # TODO(SPARK-50284): Replace when an easy Python API to construct Variants is created. 
u = udf( - lambda: VariantVal(bytes([2, 1, 0, 0, 2, 5, 98]), bytes([1, 1, 0, 1, 97])), + lambda i: VariantVal(bytes([2, 1, 0, 0, 2, 5, 97 + i]), bytes([1, 1, 0, 1, 97])), VariantType(), ) + result = self.spark.range(0, 10).select(u(col("id")).cast("string").alias("udf")).collect() + expected = [Row(udf=f'{{"a":"{chr(97 + i)}"}}') for i in range(10)] + self.assertEqual(result, expected) - with self.assertRaises(AnalysisException) as ae: - self.spark.range(0, 10).select(u()).collect() + def test_udf_with_complex_variant_output(self): + # The variant value returned corresponds to a JSON string of {"a": ""}. + # struct + # TODO(SPARK-50284): Replace when an easy Python API to construct Variants is created. + u = udf( + lambda i: {"v": VariantVal(bytes([2, 1, 0, 0, 2, 5, 97 + i]), bytes([1, 1, 0, 1, 97]))}, + StructType([StructField("v", VariantType(), True)]), + ) + result = self.spark.range(0, 10).select(u(col("id")).cast("string").alias("udf")).collect() + expected = [Row(udf=f'{{{{"a":"{chr(97 + i)}"}}}}') for i in range(10)] + self.assertEqual(result, expected) - self.check_error( - exception=ae.exception, - errorClass="DATATYPE_MISMATCH.UNSUPPORTED_UDF_OUTPUT_TYPE", - messageParameters={"sqlExpr": '"()"', "dataType": "VARIANT"}, + # array + # TODO(SPARK-50284): Replace when an easy Python API to construct Variants is created. + u = udf( + lambda i: [VariantVal(bytes([2, 1, 0, 0, 2, 5, 97 + i]), bytes([1, 1, 0, 1, 97]))], + ArrayType(VariantType()), ) + result = self.spark.range(0, 10).select(u(col("id")).cast("string").alias("udf")).collect() + expected = [Row(udf=f'[{{"a":"{chr(97 + i)}"}}]') for i in range(10)] + self.assertEqual(result, expected) - def test_udf_with_complex_variant_output(self): - # The variant value returned corresponds to a JSON string of {"a": "b"}. + # map + # TODO(SPARK-50284): Replace when an easy Python API to construct Variants is created. u = udf( - lambda: {"v", VariantVal(bytes([2, 1, 0, 0, 2, 5, 98]), bytes([1, 1, 0, 1, 97]))}, + lambda i: {"v": VariantVal(bytes([2, 1, 0, 0, 2, 5, 97 + i]), bytes([1, 1, 0, 1, 97]))}, MapType(StringType(), VariantType()), ) - - with self.assertRaises(AnalysisException) as ae: - self.spark.range(0, 10).select(u()).collect() - - self.check_error( - exception=ae.exception, - errorClass="DATATYPE_MISMATCH.UNSUPPORTED_UDF_OUTPUT_TYPE", - messageParameters={"sqlExpr": '"()"', "dataType": "MAP"}, + result = self.spark.range(0, 10).select(u(col("id")).cast("string").alias("udf")).collect() + expected = [Row(udf=f'{{v -> {{"a":"{chr(97 + i)}"}}}}') for i in range(10)] + self.assertEqual(result, expected) + + def test_chained_udfs_with_variant(self): + # TODO(SPARK-50284): Replace when an easy Python API to construct Variants is created. + udf_first = udf( + lambda i: VariantVal(bytes([2, 1, 0, 0, 2, 5, 97 + i]), bytes([1, 1, 0, 1, 97])), + VariantType(), + ) + udf_second = udf(lambda u: str(u), StringType()) + result = ( + self.spark.range(0, 10) + .select(udf_second(udf_first(col("id"))).cast("string").alias("udf")) + .collect() + ) + expected = [Row(udf=f'{{"a":"{chr(97 + i)}"}}') for i in range(10)] + self.assertEqual(result, expected) + + # struct + # TODO(SPARK-50284): Replace when an easy Python API to construct Variants is created. 
+ u_first = udf( + lambda i: {"v": VariantVal(bytes([2, 1, 0, 0, 2, 5, 97 + i]), bytes([1, 1, 0, 1, 97]))}, + StructType([StructField("v", VariantType(), True)]), + ) + u_second = udf(lambda u: str(u["v"]), StringType()) + result = self.spark.range(0, 10).select(u_second(u_first(col("id"))).alias("udf")).collect() + expected = [Row(udf=f'{{"a":"{chr(97 + i)}"}}') for i in range(10)] + self.assertEqual(result, expected) + + # array + # TODO(SPARK-50284): Replace when an easy Python API to construct Variants is created. + u_first = udf( + lambda i: [VariantVal(bytes([2, 1, 0, 0, 2, 5, 97 + i]), bytes([1, 1, 0, 1, 97]))], + ArrayType(VariantType()), + ) + u_second = udf(lambda u: str(u[0]), StringType()) + result = self.spark.range(0, 10).select(u_second(u_first(col("id"))).alias("udf")).collect() + expected = [Row(udf=f'{{"a":"{chr(97 + i)}"}}') for i in range(10)] + self.assertEqual(result, expected) + + # map + # TODO(SPARK-50284): Replace when an easy Python API to construct Variants is created. + u_first = udf( + lambda i: {"v": VariantVal(bytes([2, 1, 0, 0, 2, 5, 97 + i]), bytes([1, 1, 0, 1, 97]))}, + ArrayType(VariantType()), ) + u_second = udf(lambda u: str(u["v"]), StringType()) + result = self.spark.range(0, 10).select(u_second(u_first(col("id"))).alias("udf")).collect() + expected = [Row(udf=f'{{"a":"{chr(97 + i)}"}}') for i in range(10)] + self.assertEqual(result, expected) def test_udf_with_aggregate_function(self): df = self.spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"]) diff --git a/python/pyspark/sql/tests/test_udtf.py b/python/pyspark/sql/tests/test_udtf.py index 0d81cb5aec127..f3f993fc6a787 100644 --- a/python/pyspark/sql/tests/test_udtf.py +++ b/python/pyspark/sql/tests/test_udtf.py @@ -55,6 +55,7 @@ StringType, StructField, StructType, + VariantVal, ) from pyspark.testing import assertDataFrameEqual, assertSchemaEqual from pyspark.testing.sqlutils import ( @@ -2530,6 +2531,112 @@ def terminate(self): [Row(current=4, total=4), Row(current=13, total=4), Row(current=20, total=1)], ) + def test_udtf_with_variant_input(self): + @udtf(returnType="i int, s: string") + class TestUDTF: + def eval(self, v): + for i in range(10): + yield i, v.toJson() + + self.spark.udtf.register("test_udtf", TestUDTF) + rows = self.spark.sql('select i, s from test_udtf(parse_json(\'{"a":"b"}\'))').collect() + self.assertEqual(rows, [Row(i=n, s='{"a":"b"}') for n in range(10)]) + + def test_udtf_with_nested_variant_input(self): + # struct + @udtf(returnType="i int, s: string") + class TestUDTFStruct: + def eval(self, v): + for i in range(10): + yield i, v["v"].toJson() + + self.spark.udtf.register("test_udtf_struct", TestUDTFStruct) + rows = self.spark.sql( + "select i, s from test_udtf_struct(named_struct('v', parse_json('{\"a\":\"c\"}')))" + ).collect() + self.assertEqual(rows, [Row(i=n, s='{"a":"c"}') for n in range(10)]) + + # array + @udtf(returnType="i int, s: string") + class TestUDTFArray: + def eval(self, v): + for i in range(10): + yield i, v[0].toJson() + + self.spark.udtf.register("test_udtf_array", TestUDTFArray) + rows = self.spark.sql( + 'select i, s from test_udtf_array(array(parse_json(\'{"a":"d"}\')))' + ).collect() + self.assertEqual(rows, [Row(i=n, s='{"a":"d"}') for n in range(10)]) + + # map + @udtf(returnType="i int, s: string") + class TestUDTFMap: + def eval(self, v): + for i in range(10): + yield i, v["v"].toJson() + + self.spark.udtf.register("test_udtf_map", TestUDTFMap) + rows = self.spark.sql( + "select i, s from test_udtf_map(map('v', 
parse_json('{\"a\":\"e\"}')))" + ).collect() + self.assertEqual(rows, [Row(i=n, s='{"a":"e"}') for n in range(10)]) + + def test_udtf_with_variant_output(self): + @udtf(returnType="i int, v: variant") + class TestUDTF: + # TODO(SPARK-50284): Replace when an easy Python API to construct Variants is created. + def eval(self, n): + for i in range(n): + yield i, VariantVal(bytes([2, 1, 0, 0, 2, 5, 97 + i]), bytes([1, 1, 0, 1, 97])) + + self.spark.udtf.register("test_udtf", TestUDTF) + rows = self.spark.sql("select i, to_json(v) from test_udtf(8)").collect() + self.assertEqual(rows, [Row(i=n, s=f'{{"a":"{chr(97 + n)}"}}') for n in range(8)]) + + def test_udtf_with_nested_variant_output(self): + # struct + @udtf(returnType="i int, v: struct") + class TestUDTFStruct: + # TODO(SPARK-50284): Replace when an easy Python API to construct Variants is created. + def eval(self, n): + for i in range(n): + yield i, { + "v1": VariantVal(bytes([2, 1, 0, 0, 2, 5, 97 + i]), bytes([1, 1, 0, 1, 97])) + } + + self.spark.udtf.register("test_udtf_struct", TestUDTFStruct) + rows = self.spark.sql("select i, to_json(v.v1) from test_udtf_struct(8)").collect() + self.assertEqual(rows, [Row(i=n, s=f'{{"a":"{chr(97 + n)}"}}') for n in range(8)]) + + # array + @udtf(returnType="i int, v: array") + class TestUDTFArray: + # TODO(SPARK-50284): Replace when an easy Python API to construct Variants is created. + def eval(self, n): + for i in range(n): + yield i, [ + VariantVal(bytes([2, 1, 0, 0, 2, 5, 98 + i]), bytes([1, 1, 0, 1, 97])) + ] + + self.spark.udtf.register("test_udtf_array", TestUDTFArray) + rows = self.spark.sql("select i, to_json(v[0]) from test_udtf_array(8)").collect() + self.assertEqual(rows, [Row(i=n, s=f'{{"a":"{chr(98 + n)}"}}') for n in range(8)]) + + # map + @udtf(returnType="i int, v: map") + class TestUDTFStruct: + # TODO(SPARK-50284): Replace when an easy Python API to construct Variants is created. 
+ def eval(self, n): + for i in range(n): + yield i, { + "v1": VariantVal(bytes([2, 1, 0, 0, 2, 5, 99 + i]), bytes([1, 1, 0, 1, 97])) + } + + self.spark.udtf.register("test_udtf_struct", TestUDTFStruct) + rows = self.spark.sql("select i, to_json(v['v1']) from test_udtf_struct(8)").collect() + self.assertEqual(rows, [Row(i=n, s=f'{{"a":"{chr(99 + n)}"}}') for n in range(8)]) + class UDTFTests(BaseUDTFTestsMixin, ReusedSQLTestCase): @classmethod diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 1f3558c37d09d..03227c8c8760f 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -1571,11 +1571,9 @@ def fromInternal(self, obj: Dict) -> Optional["VariantVal"]: return None return VariantVal(obj["value"], obj["metadata"]) - def toInternal(self, obj: Any) -> Any: - raise PySparkNotImplementedError( - errorClass="NOT_IMPLEMENTED", - messageParameters={"feature": "VariantType.toInternal"}, - ) + def toInternal(self, variant: Any) -> Any: + assert isinstance(variant, VariantVal) + return {"value": variant.value, "metadata": variant.metadata} class UserDefinedType(DataType): diff --git a/sql/api/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala b/sql/api/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala index 1740cbe2957b8..55d1aff8261d4 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala @@ -141,9 +141,12 @@ private[sql] object ArrowUtils { case udt: UserDefinedType[_] => toArrowField(name, udt.sqlType, nullable, timeZoneId, largeVarTypes) case _: VariantType => - val fieldType = new FieldType( - nullable, - ArrowType.Struct.INSTANCE, + val fieldType = new FieldType(nullable, ArrowType.Struct.INSTANCE, null) + // The metadata field is tagged with additional metadata so we can identify that the arrow + // struct actually represents a variant schema. 
+ val metadataFieldType = new FieldType( + false, + toArrowType(BinaryType, timeZoneId, largeVarTypes), null, Map("variant" -> "true").asJava) new Field( @@ -151,7 +154,7 @@ private[sql] object ArrowUtils { fieldType, Seq( toArrowField("value", BinaryType, false, timeZoneId, largeVarTypes), - toArrowField("metadata", BinaryType, false, timeZoneId, largeVarTypes)).asJava) + new Field("metadata", metadataFieldType, Seq.empty[Field].asJava)).asJava) case dataType => val fieldType = new FieldType(nullable, toArrowType(dataType, timeZoneId, largeVarTypes), null) @@ -159,6 +162,16 @@ private[sql] object ArrowUtils { } } + def isVariantField(field: Field): Boolean = { + assert(field.getType.isInstanceOf[ArrowType.Struct]) + field.getChildren.asScala + .map(_.getName) + .asJava + .containsAll(Seq("value", "metadata").asJava) && field.getChildren.asScala.exists { child => + child.getName == "metadata" && child.getMetadata.getOrDefault("variant", "false") == "true" + } + } + def fromArrowField(field: Field): DataType = { field.getType match { case _: ArrowType.Map => @@ -170,12 +183,7 @@ private[sql] object ArrowUtils { val elementField = field.getChildren().get(0) val elementType = fromArrowField(elementField) ArrayType(elementType, containsNull = elementField.isNullable) - case ArrowType.Struct.INSTANCE - if field.getMetadata.getOrDefault("variant", "") == "true" - && field.getChildren.asScala - .map(_.getName) - .asJava - .containsAll(Seq("value", "metadata").asJava) => + case ArrowType.Struct.INSTANCE if isVariantField(field) => VariantType case ArrowType.Struct.INSTANCE => val fields = field.getChildren().asScala.map { child => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala index 4fa9d746f25d6..53273b29a7c17 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala @@ -20,10 +20,9 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkException.internalError import org.apache.spark.api.python.{PythonEvalType, PythonFunction} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedException} +import org.apache.spark.sql.catalyst.analysis.UnresolvedException import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateFunction import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} -import org.apache.spark.sql.catalyst.expressions.variant.VariantExpressionEvalUtils import org.apache.spark.sql.catalyst.trees.TreePattern.{PYTHON_UDF, TreePattern} import org.apache.spark.sql.catalyst.util.toPrettySQL import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} @@ -64,23 +63,6 @@ trait PythonFuncExpression extends NonSQLExpression with UserDefinedExpression { override def toString: String = s"$name(${children.mkString(", ")})#${resultId.id}$typeSuffix" override def nullable: Boolean = true - - override def checkInputDataTypes(): TypeCheckResult = { - val check = super.checkInputDataTypes() - if (check.isFailure) { - check - } else { - val exprReturningVariant = children.collectFirst { - case e: Expression if VariantExpressionEvalUtils.typeContainsVariant(e.dataType) => e - } - exprReturningVariant match { - case Some(e) => TypeCheckResult.DataTypeMismatch( - errorSubClass = 
"UNSUPPORTED_UDF_INPUT_TYPE", - messageParameters = Map("dataType" -> s"${e.dataType.sql}")) - case None => TypeCheckResult.TypeCheckSuccess - } - } - } } /** @@ -97,10 +79,6 @@ case class PythonUDF( resultId: ExprId = NamedExpression.newExprId) extends Expression with PythonFuncExpression with Unevaluable { - if (VariantExpressionEvalUtils.typeContainsVariant(dataType)) { - throw QueryCompilationErrors.unsupportedUDFOuptutType(this, dataType) - } - lazy val resultAttribute: Attribute = AttributeReference(toPrettySQL(this), dataType, nullable)( exprId = resultId) @@ -143,10 +121,6 @@ case class PythonUDAF( resultId: ExprId = NamedExpression.newExprId) extends UnevaluableAggregateFunc with PythonFuncExpression { - if (VariantExpressionEvalUtils.typeContainsVariant(dataType)) { - throw QueryCompilationErrors.unsupportedUDFOuptutType(this, dataType) - } - override def evalType: Int = PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF override def sql(isDistinct: Boolean): String = { @@ -213,13 +187,6 @@ case class PythonUDTF( pythonUDTFPartitionColumnIndexes: Option[PythonUDTFPartitionColumnIndexes] = None) extends UnevaluableGenerator with PythonFuncExpression { - elementSchema.collectFirst { - case sf: StructField if VariantExpressionEvalUtils.typeContainsVariant(sf.dataType) => sf - } match { - case Some(sf) => throw QueryCompilationErrors.unsupportedUDFOuptutType(this, sf.dataType) - case None => - } - override lazy val canonicalized: Expression = { val canonicalizedChildren = children.map(_.canonicalized) // `resultId` can be seen as cosmetic variation in PythonUDTF, as it doesn't affect the result. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala index 8332d82b7a7b8..1b2596e79ffec 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala @@ -3816,12 +3816,6 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase with Compilat "class" -> unsupported.getClass.toString)) } - def unsupportedUDFOuptutType(expr: Expression, dt: DataType): Throwable = { - new AnalysisException( - errorClass = "DATATYPE_MISMATCH.UNSUPPORTED_UDF_OUTPUT_TYPE", - messageParameters = Map("sqlExpr" -> toSQLExpr(expr), "dataType" -> dt.sql)) - } - def funcBuildError(funcName: String, cause: Exception): Throwable = { cause.getCause match { case st: SparkThrowable with Throwable => st diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala index ca7703bef48bb..065b4b8c821a6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala @@ -373,7 +373,7 @@ private[arrow] class StructWriter( val valueVector: StructVector, children: Array[ArrowFieldWriter]) extends ArrowFieldWriter { - lazy val isVariant = valueVector.getField.getMetadata.get("variant") == "true" + lazy val isVariant = ArrowUtils.isVariantField(valueVector.getField) override def setNull(): Unit = { var i = 0 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala index fca277dae5d55..fd7ccb2189bff 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala @@ -31,12 +31,12 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData} import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.unsafe.types.{UTF8String, VariantVal} object EvaluatePython { def needConversionInPython(dt: DataType): Boolean = dt match { - case DateType | TimestampType | TimestampNTZType | _: DayTimeIntervalType => true + case DateType | TimestampType | TimestampNTZType | VariantType | _: DayTimeIntervalType => true case _: StructType => true case _: UserDefinedType[_] => true case ArrayType(elementType, _) => needConversionInPython(elementType) @@ -201,6 +201,13 @@ object EvaluatePython { case udt: UserDefinedType[_] => makeFromJava(udt.sqlType) + case VariantType => (obj: Any) => nullSafeConvert(obj) { + case s: java.util.HashMap[_, _] => + new VariantVal( + s.get("value").asInstanceOf[Array[Byte]], s.get("metadata").asInstanceOf[Array[Byte]] + ) + } + case other => (obj: Any) => nullSafeConvert(obj)(PartialFunction.empty) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDFSuite.scala index 0339f7461f0a3..4b46331be107a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDFSuite.scala @@ -86,62 +86,6 @@ class PythonUDFSuite extends QueryTest with SharedSparkSession { checkAnswer(actual, expected) } - test("variant input to pandas grouped agg UDF") { - assume(shouldTestPandasUDFs) - val df = spark.range(0, 10).selectExpr( - """parse_json(format_string('{"%s": "test"}', id)) as v""") - - val testUdf = TestGroupedAggPandasUDFStringifiedMax(name = "pandas_udf") - checkError( - exception = intercept[AnalysisException] { - df.agg(testUdf(df("v"))).collect() - }, - condition = "DATATYPE_MISMATCH.UNSUPPORTED_UDF_INPUT_TYPE", - parameters = Map("sqlExpr" -> "\"pandas_udf(v)\"", "dataType" -> "VARIANT")) - } - - test("complex variant input to pandas grouped agg UDF") { - assume(shouldTestPandasUDFs) - val df = spark.range(0, 10).selectExpr( - """array(parse_json(format_string('{"%s": "test"}', id))) as arr_v""") - - val testUdf = TestGroupedAggPandasUDFStringifiedMax(name = "pandas_udf") - checkError( - exception = intercept[AnalysisException] { - df.agg(testUdf(df("arr_v"))).collect() - }, - condition = "DATATYPE_MISMATCH.UNSUPPORTED_UDF_INPUT_TYPE", - parameters = Map("sqlExpr" -> "\"pandas_udf(arr_v)\"", "dataType" -> "ARRAY")) - } - - test("variant output to pandas grouped agg UDF") { - assume(shouldTestPandasUDFs) - val df = spark.range(0, 10).toDF("id") - - val testUdf = TestGroupedAggPandasUDFReturnVariant(name = "pandas_udf") - checkError( - exception = intercept[AnalysisException] { - df.agg(testUdf(df("id"))).collect() - }, - condition = "DATATYPE_MISMATCH.UNSUPPORTED_UDF_OUTPUT_TYPE", - parameters = Map("sqlExpr" -> "\"pandas_udf(id)\"", "dataType" -> "VARIANT")) - } - - test("complex variant output to pandas grouped agg UDF") { - assume(shouldTestPandasUDFs) - val df = spark.range(0, 10).toDF("id") - - val testUdf = TestGroupedAggPandasUDFReturnComplexVariant(name = 
"pandas_udf") - checkError( - exception = intercept[AnalysisException] { - df.agg(testUdf(df("id"))).collect() - }, - condition = "DATATYPE_MISMATCH.UNSUPPORTED_UDF_OUTPUT_TYPE", - parameters = Map( - "sqlExpr" -> "\"pandas_udf(id)\"", - "dataType" -> "STRUCT>")) - } - test("SPARK-34265: Instrument Python UDF execution using SQL Metrics") { assume(shouldTestPythonUDFs) val pythonSQLMetrics = List( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDTFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDTFSuite.scala index 041bd143067a7..a6bf95be837da 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDTFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDTFSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution.python import org.apache.spark.api.python.PythonEvalType -import org.apache.spark.sql.{AnalysisException, IntegratedUDFTestUtils, QueryTest, Row} +import org.apache.spark.sql.{IntegratedUDFTestUtils, QueryTest, Row} import org.apache.spark.sql.catalyst.expressions.{Add, Alias, Expression, FunctionTableSubqueryArgumentExpression, Literal} import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, OneRowRelation, Project, Repartition, RepartitionByExpression, Sort, SubqueryAlias} @@ -124,84 +124,6 @@ class PythonUDTFSuite extends QueryTest with SharedSparkSession { checkAnswer(df, Seq(Row(1, 2, -1), Row(1, 2, 1), Row(1, 2, 3))) } - test("Simple variant input UDTF") { - assume(shouldTestPythonUDFs) - withTempView("t") { - spark.udtf.registerPython("variantInputUDTF", variantInputUDTF) - spark.range(0, 10).selectExpr("parse_json(cast(id as string)) v").createOrReplaceTempView("t") - checkError( - exception = intercept[AnalysisException] { - spark.sql("select udtf.* from t, lateral variantInputUDTF(v) udtf").collect() - }, - condition = "DATATYPE_MISMATCH.UNSUPPORTED_UDF_INPUT_TYPE", - parameters = Map( - "sqlExpr" -> """"InputVariantUDTF\(outer\(v#\d+\)\)"""", - "dataType" -> "VARIANT"), - matchPVals = true, - queryContext = Array(ExpectedContext( - fragment = "variantInputUDTF(v) udtf", - start = 30, - stop = 53))) - } - } - - test("Complex variant input UDTF") { - assume(shouldTestPythonUDFs) - withTempView("t") { - spark.udtf.registerPython("variantInputUDTF", variantInputUDTF) - spark.range(0, 10) - .selectExpr("map(id, parse_json(cast(id as string))) map_v") - .createOrReplaceTempView("t") - checkError( - exception = intercept[AnalysisException] { - spark.sql("select udtf.* from t, lateral variantInputUDTF(map_v) udtf").collect() - }, - condition = "DATATYPE_MISMATCH.UNSUPPORTED_UDF_INPUT_TYPE", - parameters = Map( - "sqlExpr" -> """"InputVariantUDTF\(outer\(map_v#\d+\)\)"""", - "dataType" -> "MAP"), - matchPVals = true, - queryContext = Array(ExpectedContext( - fragment = "variantInputUDTF(map_v) udtf", - start = 30, - stop = 57))) - } - } - - test("Simple variant output UDTF") { - assume(shouldTestPythonUDFs) - spark.udtf.registerPython("variantOutUDTF", variantOutputUDTF) - checkError( - exception = intercept[AnalysisException] { - spark.sql("select * from variantOutUDTF()").collect() - }, - condition = "DATATYPE_MISMATCH.UNSUPPORTED_UDF_OUTPUT_TYPE", - parameters = Map( - "sqlExpr" -> "\"SimpleOutputVariantUDTF()\"", - "dataType" -> "VARIANT"), - context = ExpectedContext( - fragment = "variantOutUDTF()", - start = 14, - stop = 29)) - } - - test("Complex 
variant output UDTF") { - assume(shouldTestPythonUDFs) - spark.udtf.registerPython("arrayOfVariantOutUDTF", arrayOfVariantOutputUDTF) - checkError( - exception = intercept[AnalysisException] { - spark.sql("select * from arrayOfVariantOutUDTF()").collect() - }, - condition = "DATATYPE_MISMATCH.UNSUPPORTED_UDF_OUTPUT_TYPE", - parameters = Map( - "sqlExpr" -> "\"OutputArrayOfVariantUDTF()\"", - "dataType" -> "ARRAY"), - context = ExpectedContext( - fragment = "arrayOfVariantOutUDTF()", - start = 14, - stop = 36)) - } - test("PythonUDTF with lateral join") { assume(shouldTestPythonUDFs) withTempView("t") { From 3c3d1a6a9f6daf6db5148f1423f49f4bce142858 Mon Sep 17 00:00:00 2001 From: cashmand Date: Wed, 13 Nov 2024 09:15:08 +0800 Subject: [PATCH 11/39] [SPARK-48898][SQL] Add Variant shredding functions ### What changes were proposed in this pull request? This is a first step towards adding Variant shredding support for the Parquet writer. It adds functionality to convert a Variant value to an InternalRow that matches the current shredding spec in https://github.com/apache/parquet-format/pull/461. Once this merges, the next step will be to set up the Parquet writer to accept a shredding schema, and write these InternalRow values to Parquet instead of the raw Variant binary. ### Why are the changes needed? First step towards adding support for shredding, which can improve Variant performance (and will be important for functionality on the read side once other tools begin writing shredded Variant columns to Parquet). ### Does this PR introduce _any_ user-facing change? No, none of this code is currently called outside of the added tests. ### How was this patch tested? Unit tests. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48779 from cashmand/SPARK-48898-write-shredding. Authored-by: cashmand Signed-off-by: Wenchen Fan --- .../resources/error/error-conditions.json | 6 + .../apache/spark/types/variant/Variant.java | 12 + .../spark/types/variant/VariantBuilder.java | 29 +- .../spark/types/variant/VariantSchema.java | 153 +++++++++ .../types/variant/VariantShreddingWriter.java | 298 ++++++++++++++++++ .../spark/types/variant/VariantUtil.java | 8 +- .../sql/errors/QueryCompilationErrors.scala | 5 + .../parquet/SparkShreddingUtils.scala | 220 +++++++++++++ .../sql/VariantWriteShreddingSuite.scala | 218 +++++++++++++ 9 files changed, 942 insertions(+), 7 deletions(-) create mode 100644 common/variant/src/main/java/org/apache/spark/types/variant/VariantSchema.java create mode 100644 common/variant/src/main/java/org/apache/spark/types/variant/VariantShreddingWriter.java create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/SparkShreddingUtils.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/VariantWriteShreddingSuite.scala diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json index 7e7b84f2332b1..a8e60d1850e2e 100644 --- a/common/utils/src/main/resources/error/error-conditions.json +++ b/common/utils/src/main/resources/error/error-conditions.json @@ -3309,6 +3309,12 @@ ], "sqlState" : "22023" }, + "INVALID_VARIANT_SHREDDING_SCHEMA" : { + "message" : [ + "The schema `` is not a valid variant shredding schema." 
+ ], + "sqlState" : "22023" + }, "INVALID_WHERE_CONDITION" : { "message" : [ "The WHERE condition contains invalid expressions: .", diff --git a/common/variant/src/main/java/org/apache/spark/types/variant/Variant.java b/common/variant/src/main/java/org/apache/spark/types/variant/Variant.java index 58a827847b257..a6fc6b534ee02 100644 --- a/common/variant/src/main/java/org/apache/spark/types/variant/Variant.java +++ b/common/variant/src/main/java/org/apache/spark/types/variant/Variant.java @@ -193,6 +193,18 @@ public ObjectField getFieldAtIndex(int index) { }); } + // Get the dictionary ID for the object field at the `index` slot. Throws malformedVariant if + // `index` is out of the bound of `[0, objectSize())`. + // It is only legal to call it when `getType()` is `Type.OBJECT`. + public int getDictionaryIdAtIndex(int index) { + return handleObject(value, pos, (size, idSize, offsetSize, idStart, offsetStart, dataStart) -> { + if (index < 0 || index >= size) { + throw malformedVariant(); + } + return readUnsigned(value, idStart + idSize * index, idSize); + }); + } + // Get the number of array elements in the variant. // It is only legal to call it when `getType()` is `Type.ARRAY`. public int arraySize() { diff --git a/common/variant/src/main/java/org/apache/spark/types/variant/VariantBuilder.java b/common/variant/src/main/java/org/apache/spark/types/variant/VariantBuilder.java index be9d380dd9e3b..32595baf6a4f2 100644 --- a/common/variant/src/main/java/org/apache/spark/types/variant/VariantBuilder.java +++ b/common/variant/src/main/java/org/apache/spark/types/variant/VariantBuilder.java @@ -107,6 +107,14 @@ public Variant result() { return new Variant(Arrays.copyOfRange(writeBuffer, 0, writePos), metadata); } + // Return the variant value only, without metadata. + // Used in shredding to produce a final value, where all shredded values refer to a common + // metadata. It is expected to be called instead of `result()`, although it is valid to call both + // methods, in any order. + public byte[] valueWithoutMetadata() { + return Arrays.copyOfRange(writeBuffer, 0, writePos); + } + public void appendString(String str) { byte[] text = str.getBytes(StandardCharsets.UTF_8); boolean longStr = text.length > MAX_SHORT_STR_SIZE; @@ -404,15 +412,26 @@ private void appendVariantImpl(byte[] value, byte[] metadata, int pos) { }); break; default: - int size = valueSize(value, pos); - checkIndex(pos + size - 1, value.length); - checkCapacity(size); - System.arraycopy(value, pos, writeBuffer, writePos, size); - writePos += size; + shallowAppendVariantImpl(value, pos); break; } } + // Append the variant value without rewriting or creating any metadata. This is used when + // building an object during shredding, where there is a fixed pre-existing metadata that + // all shredded values will refer to. 
+ public void shallowAppendVariant(Variant v) { + shallowAppendVariantImpl(v.value, v.pos); + } + + private void shallowAppendVariantImpl(byte[] value, int pos) { + int size = valueSize(value, pos); + checkIndex(pos + size - 1, value.length); + checkCapacity(size); + System.arraycopy(value, pos, writeBuffer, writePos, size); + writePos += size; + } + private void checkCapacity(int additional) { int required = writePos + additional; if (required > writeBuffer.length) { diff --git a/common/variant/src/main/java/org/apache/spark/types/variant/VariantSchema.java b/common/variant/src/main/java/org/apache/spark/types/variant/VariantSchema.java new file mode 100644 index 0000000000000..551e46214859a --- /dev/null +++ b/common/variant/src/main/java/org/apache/spark/types/variant/VariantSchema.java @@ -0,0 +1,153 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.types.variant; + +import java.util.HashMap; +import java.util.Map; + +/** + * Defines a valid shredding schema, as described in + * https://github.com/apache/parquet-format/blob/master/VariantShredding.md. + * A shredding schema contains a value and optional typed_value field. + * If a typed_value is an array or struct, it recursively contain its own shredding schema for + * elements and fields, respectively. + * The schema also contains a metadata field at the top level, but not in recursively shredded + * fields. + */ +public class VariantSchema { + + // Represents one field of an object in the shredding schema. 
+ public static final class ObjectField { + public final String fieldName; + public final VariantSchema schema; + + public ObjectField(String fieldName, VariantSchema schema) { + this.fieldName = fieldName; + this.schema = schema; + } + + @Override + public String toString() { + return "ObjectField{" + + "fieldName=" + fieldName + + ", schema=" + schema + + '}'; + } + } + + public abstract static class ScalarType { + } + + public static final class StringType extends ScalarType { + } + + public enum IntegralSize { + BYTE, SHORT, INT, LONG + } + + public static final class IntegralType extends ScalarType { + public final IntegralSize size; + + public IntegralType(IntegralSize size) { + this.size = size; + } + } + + public static final class FloatType extends ScalarType { + } + + public static final class DoubleType extends ScalarType { + } + + public static final class BooleanType extends ScalarType { + } + + public static final class BinaryType extends ScalarType { + } + + public static final class DecimalType extends ScalarType { + public final int precision; + public final int scale; + + public DecimalType(int precision, int scale) { + this.precision = precision; + this.scale = scale; + } + } + + public static final class DateType extends ScalarType { + } + + public static final class TimestampType extends ScalarType { + } + + public static final class TimestampNTZType extends ScalarType { + } + + // The index of the typed_value, value, and metadata fields in the schema, respectively. If a + // given field is not in the schema, its value must be set to -1 to indicate that it is invalid. + // The indices of valid fields should be contiguous and start from 0. + public final int typedIdx; + public final int variantIdx; + // topLevelMetadataIdx must be non-negative in the top-level schema, and -1 at all other nesting + // levels. + public final int topLevelMetadataIdx; + // The number of fields in the schema. I.e. a value between 1 and 3, depending on which of value, + // typed_value and metadata are present. + public final int numFields; + + public final ScalarType scalarSchema; + public final ObjectField[] objectSchema; + // Map for fast lookup of object fields by name. The values are an index into `objectSchema`. 
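+ // For example, for an object schema with fields ["a", "b"] (in that order), this map holds
+ // {"a" -> 0, "b" -> 1}, matching how the constructor below populates it.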
+ public final Map objectSchemaMap; + public final VariantSchema arraySchema; + + public VariantSchema(int typedIdx, int variantIdx, int topLevelMetadataIdx, int numFields, + ScalarType scalarSchema, ObjectField[] objectSchema, + VariantSchema arraySchema) { + this.typedIdx = typedIdx; + this.numFields = numFields; + this.variantIdx = variantIdx; + this.topLevelMetadataIdx = topLevelMetadataIdx; + this.scalarSchema = scalarSchema; + this.objectSchema = objectSchema; + if (objectSchema != null) { + objectSchemaMap = new HashMap<>(); + for (int i = 0; i < objectSchema.length; i++) { + objectSchemaMap.put(objectSchema[i].fieldName, i); + } + } else { + objectSchemaMap = null; + } + + this.arraySchema = arraySchema; + } + + @Override + public String toString() { + return "VariantSchema{" + + "typedIdx=" + typedIdx + + ", variantIdx=" + variantIdx + + ", topLevelMetadataIdx=" + topLevelMetadataIdx + + ", numFields=" + numFields + + ", scalarSchema=" + scalarSchema + + ", objectSchema=" + objectSchema + + ", arraySchema=" + arraySchema + + '}'; + } +} diff --git a/common/variant/src/main/java/org/apache/spark/types/variant/VariantShreddingWriter.java b/common/variant/src/main/java/org/apache/spark/types/variant/VariantShreddingWriter.java new file mode 100644 index 0000000000000..b5f8ea0a1484b --- /dev/null +++ b/common/variant/src/main/java/org/apache/spark/types/variant/VariantShreddingWriter.java @@ -0,0 +1,298 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.types.variant; + +import java.math.BigDecimal; +import java.math.RoundingMode; +import java.util.ArrayList; + +/** + * Class to implement shredding a Variant value. + */ +public class VariantShreddingWriter { + + // Interface to build up a shredded result. Callers should implement a ShreddedResultBuilder to + // create an empty result with a given schema. The castShredded method will call one or more of + // the add* methods to populate it. + public interface ShreddedResult { + // Create an array. The elements are the result of shredding each element. + void addArray(ShreddedResult[] array); + // Create an object. The values are the result of shredding each field, order by the index in + // objectSchema. Missing fields are populated with an empty result. + void addObject(ShreddedResult[] values); + void addVariantValue(byte[] result); + // Add a scalar to typed_value. The type of Object depends on the scalarSchema in the shredding + // schema. + void addScalar(Object result); + void addMetadata(byte[] result); + } + + public interface ShreddedResultBuilder { + ShreddedResult createEmpty(VariantSchema schema); + + // If true, we will shred decimals to a different scale or to integers, as long as they are + // numerically equivalent. 
Similarly, integers will be allowed to shred to decimals. + boolean allowNumericScaleChanges(); + } + + /** + * Converts an input variant into shredded components. Returns the shredded result, as well + * as the original Variant with shredded fields removed. + * `dataType` must be a valid shredding schema, as described in + * https://github.com/apache/parquet-format/blob/master/VariantShredding.md. + */ + public static ShreddedResult castShredded( + Variant v, + VariantSchema schema, + ShreddedResultBuilder builder) { + VariantUtil.Type variantType = v.getType(); + ShreddedResult result = builder.createEmpty(schema); + + if (schema.topLevelMetadataIdx >= 0) { + result.addMetadata(v.getMetadata()); + } + + if (schema.arraySchema != null && variantType == VariantUtil.Type.ARRAY) { + // The array element is always a struct containing untyped and typed fields. + VariantSchema elementSchema = schema.arraySchema; + int size = v.arraySize(); + ShreddedResult[] array = new ShreddedResult[size]; + for (int i = 0; i < size; ++i) { + ShreddedResult shreddedArray = castShredded(v.getElementAtIndex(i), elementSchema, builder); + array[i] = shreddedArray; + } + result.addArray(array); + } else if (schema.objectSchema != null && variantType == VariantUtil.Type.OBJECT) { + VariantSchema.ObjectField[] objectSchema = schema.objectSchema; + ShreddedResult[] shreddedValues = new ShreddedResult[objectSchema.length]; + + // Create a variantBuilder for any field that exist in `v`, but not in the shredding schema. + VariantBuilder variantBuilder = new VariantBuilder(false); + ArrayList fieldEntries = new ArrayList<>(); + // Keep track of which schema fields we actually found in the Variant value. + int numFieldsMatched = 0; + int start = variantBuilder.getWritePos(); + for (int i = 0; i < v.objectSize(); ++i) { + Variant.ObjectField field = v.getFieldAtIndex(i); + Integer fieldIdx = schema.objectSchemaMap.get(field.key); + if (fieldIdx != null) { + // The field exists in the shredding schema. Recursively shred, and write the result. + ShreddedResult shreddedField = castShredded( + field.value, objectSchema[fieldIdx].schema, builder); + shreddedValues[fieldIdx] = shreddedField; + numFieldsMatched++; + } else { + // The field is not shredded. Put it in the untyped_value column. + int id = v.getDictionaryIdAtIndex(i); + fieldEntries.add(new VariantBuilder.FieldEntry( + field.key, id, variantBuilder.getWritePos() - start)); + variantBuilder.appendVariant(field.value); + } + } + if (numFieldsMatched < objectSchema.length) { + // Set missing fields to non-null with all fields set to null. + for (int i = 0; i < objectSchema.length; ++i) { + if (shreddedValues[i] == null) { + VariantSchema.ObjectField fieldSchema = objectSchema[i]; + ShreddedResult emptyChild = builder.createEmpty(fieldSchema.schema); + shreddedValues[i] = emptyChild; + numFieldsMatched += 1; + } + } + } + if (numFieldsMatched != objectSchema.length) { + // Since we just filled in all the null entries, this can only happen if we tried to write + // to the same field twice; i.e. the Variant contained duplicate fields, which is invalid. + throw VariantUtil.malformedVariant(); + } + result.addObject(shreddedValues); + if (variantBuilder.getWritePos() != start) { + // We added something to the untyped value. 
+ variantBuilder.finishWritingObject(start, fieldEntries); + result.addVariantValue(variantBuilder.valueWithoutMetadata()); + } + } else if (schema.scalarSchema != null) { + VariantSchema.ScalarType scalarType = schema.scalarSchema; + Object typedValue = tryTypedShred(v, variantType, scalarType, builder); + if (typedValue != null) { + // Store the typed value. + result.addScalar(typedValue); + } else { + VariantBuilder variantBuilder = new VariantBuilder(false); + variantBuilder.appendVariant(v); + result.addVariantValue(v.getValue()); + } + } else { + // Store in untyped. + result.addVariantValue(v.getValue()); + } + return result; + } + + /** + * Tries to cast a Variant into a typed value. If the cast fails, returns null. + * + * @param v + * @param variantType The Variant Type of v + * @param targetType The target type + * @return The scalar value, or null if the cast is not valid. + */ + private static Object tryTypedShred( + Variant v, + VariantUtil.Type variantType, + VariantSchema.ScalarType targetType, + ShreddedResultBuilder builder) { + switch (variantType) { + case LONG: + if (targetType instanceof VariantSchema.IntegralType integralType) { + // Check that the target type can hold the actual value. + VariantSchema.IntegralSize size = integralType.size; + long value = v.getLong(); + switch (size) { + case BYTE: + if (value == (byte) value) { + return (byte) value; + } + break; + case SHORT: + if (value == (short) value) { + return (short) value; + } + break; + case INT: + if (value == (int) value) { + return (int) value; + } + break; + case LONG: + return value; + } + } else if (targetType instanceof VariantSchema.DecimalType decimalType && + builder.allowNumericScaleChanges()) { + // If the integer can fit in the given decimal precision, allow it. + long value = v.getLong(); + // Set to the requested scale, and check if the precision is large enough. + BigDecimal decimalValue = BigDecimal.valueOf(value); + BigDecimal scaledValue = decimalValue.setScale(decimalType.scale); + // The initial value should have scale 0, so rescaling shouldn't lose information. + assert(decimalValue.compareTo(scaledValue) == 0); + if (scaledValue.precision() <= decimalType.precision) { + return scaledValue; + } + } + break; + case DECIMAL: + if (targetType instanceof VariantSchema.DecimalType decimalType) { + // Use getDecimalWithOriginalScale so that we retain scale information if + // allowNumericScaleChanges() is false. + BigDecimal value = VariantUtil.getDecimalWithOriginalScale(v.value, v.pos); + if (value.precision() <= decimalType.precision && + value.scale() == decimalType.scale) { + return value; + } + if (builder.allowNumericScaleChanges()) { + // Convert to the target scale, and see if it fits. Rounding mode doesn't matter, + // since we'll reject it if it turned out to require rounding. + BigDecimal scaledValue = value.setScale(decimalType.scale, RoundingMode.FLOOR); + if (scaledValue.compareTo(value) == 0 && + scaledValue.precision() <= decimalType.precision) { + return scaledValue; + } + } + } else if (targetType instanceof VariantSchema.IntegralType integralType && + builder.allowNumericScaleChanges()) { + // Check if the decimal happens to be an integer. + BigDecimal value = v.getDecimal(); + VariantSchema.IntegralSize size = integralType.size; + // Try to cast to the appropriate type, and check if any information is lost. 
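+ // For example, the decimal 12.00 can be shredded to INT as the value 12, while 12.50 cannot,
+ // because converting the result back to decimal would not compare equal to the original.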
+ switch (size) { + case BYTE: + if (value.compareTo(BigDecimal.valueOf(value.byteValue())) == 0) { + return value.byteValue(); + } + break; + case SHORT: + if (value.compareTo(BigDecimal.valueOf(value.shortValue())) == 0) { + return value.shortValue(); + } + break; + case INT: + if (value.compareTo(BigDecimal.valueOf(value.intValue())) == 0) { + return value.intValue(); + } + break; + case LONG: + if (value.compareTo(BigDecimal.valueOf(value.longValue())) == 0) { + return value.longValue(); + } + } + } + break; + case BOOLEAN: + if (targetType instanceof VariantSchema.BooleanType) { + return v.getBoolean(); + } + break; + case STRING: + if (targetType instanceof VariantSchema.StringType) { + return v.getString(); + } + break; + case DOUBLE: + if (targetType instanceof VariantSchema.DoubleType) { + return v.getDouble(); + } + break; + case DATE: + if (targetType instanceof VariantSchema.DateType) { + return (int) v.getLong(); + } + break; + case TIMESTAMP: + if (targetType instanceof VariantSchema.TimestampType) { + return v.getLong(); + } + break; + case TIMESTAMP_NTZ: + if (targetType instanceof VariantSchema.TimestampNTZType) { + return v.getLong(); + } + break; + case FLOAT: + if (targetType instanceof VariantSchema.FloatType) { + return v.getFloat(); + } + break; + case BINARY: + if (targetType instanceof VariantSchema.BinaryType) { + return v.getBinary(); + } + break; + } + // The stored type does not match the requested shredding type. Return null, and the caller + // will store the result in untyped_value. + return null; + } + + // Add the result to the shredding result. + private static void addVariantValueVariant(Variant variantResult, + VariantSchema schema, ShreddedResult result) { + result.addVariantValue(variantResult.getValue()); + } + +} diff --git a/common/variant/src/main/java/org/apache/spark/types/variant/VariantUtil.java b/common/variant/src/main/java/org/apache/spark/types/variant/VariantUtil.java index f859a7be0c4b3..86609eef5d908 100644 --- a/common/variant/src/main/java/org/apache/spark/types/variant/VariantUtil.java +++ b/common/variant/src/main/java/org/apache/spark/types/variant/VariantUtil.java @@ -413,7 +413,7 @@ private static void checkDecimal(BigDecimal d, int maxPrecision) { // Get a decimal value from variant value `value[pos...]`. // Throw `MALFORMED_VARIANT` if the variant is malformed. - public static BigDecimal getDecimal(byte[] value, int pos) { + public static BigDecimal getDecimalWithOriginalScale(byte[] value, int pos) { checkIndex(pos, value.length); int basicType = value[pos] & BASIC_TYPE_MASK; int typeInfo = (value[pos] >> BASIC_TYPE_BITS) & TYPE_INFO_MASK; @@ -445,7 +445,11 @@ public static BigDecimal getDecimal(byte[] value, int pos) { default: throw unexpectedType(Type.DECIMAL); } - return result.stripTrailingZeros(); + return result; + } + + public static BigDecimal getDecimal(byte[] value, int pos) { + return getDecimalWithOriginalScale(value, pos).stripTrailingZeros(); } // Get a float value from variant value `value[pos...]`. 
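To make the end-to-end flow of the shredding pieces concrete, here is a small illustrative Scala sketch; it is not part of the patch itself, but it only uses APIs introduced by this change (`variantShreddingSchema`, `buildVariantSchema`, `castShredded`) together with the existing `parse_json` evaluation helper, and it mirrors the "happy path" case exercised in the new `VariantWriteShreddingSuite`:

```scala
import org.apache.spark.sql.catalyst.expressions.variant.VariantExpressionEvalUtils
import org.apache.spark.sql.execution.datasources.parquet.SparkShreddingUtils
import org.apache.spark.sql.types.StructType
import org.apache.spark.types.variant.Variant
import org.apache.spark.unsafe.types.UTF8String

// Parse a JSON document into a Variant (a pair of value and metadata buffers).
val parsed = VariantExpressionEvalUtils.parseJson(
  UTF8String.fromString("""{"a": 1, "b": "hello"}"""))
val variant = new Variant(parsed.getValue, parsed.getMetadata)

// Derive the shredding schema for the target type struct<a int, b string>, i.e. roughly:
//   struct<metadata binary,
//          value binary,
//          typed_value struct<a struct<value binary, typed_value int>,
//                             b struct<value binary, typed_value string>>>
val shreddingSchema =
  SparkShreddingUtils.variantShreddingSchema(StructType.fromDDL("a int, b string"))

// Shred the variant into an InternalRow laid out according to that schema. Fields that match
// the requested type land in typed_value; anything that does not match is kept as a binary
// residual in value, and all shredded parts share the top-level metadata dictionary.
val shredded = SparkShreddingUtils.castShredded(
  variant, SparkShreddingUtils.buildVariantSchema(shreddingSchema))
```

In this sketch the shredded row would carry the original metadata, a null residual `value`, and a fully populated `typed_value` struct, which is the shape the Parquet writer is expected to consume in the follow-up work described above.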
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala index 1b2596e79ffec..f8a0286c2f941 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala @@ -1989,6 +1989,11 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase with Compilat messageParameters = Map("field" -> field)) } + def invalidVariantShreddingSchema(schema: DataType): Throwable = { + new AnalysisException(errorClass = "INVALID_VARIANT_SHREDDING_SCHEMA", + messageParameters = Map("schema" -> toSQLType(schema))) + } + def invalidVariantWrongNumFieldsError(): Throwable = { new AnalysisException(errorClass = "INVALID_VARIANT_FROM_PARQUET.WRONG_NUM_FIELDS", messageParameters = Map.empty) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/SparkShreddingUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/SparkShreddingUtils.scala new file mode 100644 index 0000000000000..2b81668b88b87 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/SparkShreddingUtils.scala @@ -0,0 +1,220 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.parquet + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.util.GenericArrayData +import org.apache.spark.sql.errors.QueryCompilationErrors +import org.apache.spark.sql.types._ +import org.apache.spark.types.variant._ +import org.apache.spark.unsafe.types._ + +case object SparkShreddingUtils { + val VariantValueFieldName = "value"; + val TypedValueFieldName = "typed_value"; + val MetadataFieldName = "metadata"; + + def buildVariantSchema(schema: DataType): VariantSchema = { + schema match { + case s: StructType => buildVariantSchema(s, topLevel = true) + case _ => throw QueryCompilationErrors.invalidVariantShreddingSchema(schema) + } + } + + /** + * Given an expected schema of a Variant value, returns a suitable schema for shredding, by + * inserting appropriate intermediate value/typed_value fields at each level. 
+ * For example, to represent the JSON {"a": 1, "b": "hello"}, + * the schema struct could be passed into this function, and it would return + * the shredding schema: + * struct< + * metadata: binary, + * value: binary, + * typed_value: struct< + * a: struct, + * b: struct>> + * + */ + def variantShreddingSchema(dataType: DataType, isTopLevel: Boolean = true): StructType = { + val fields = dataType match { + case ArrayType(elementType, containsNull) => + val arrayShreddingSchema = + ArrayType(variantShreddingSchema(elementType, false), containsNull) + Seq( + StructField(VariantValueFieldName, BinaryType, nullable = true), + StructField(TypedValueFieldName, arrayShreddingSchema, nullable = true) + ) + case StructType(fields) => + val objectShreddingSchema = StructType(fields.map(f => + f.copy(dataType = variantShreddingSchema(f.dataType, false)))) + Seq( + StructField(VariantValueFieldName, BinaryType, nullable = true), + StructField(TypedValueFieldName, objectShreddingSchema, nullable = true) + ) + case VariantType => + // For Variant, we don't need a typed column + Seq( + StructField(VariantValueFieldName, BinaryType, nullable = true) + ) + case _: NumericType | BooleanType | _: StringType | BinaryType | _: DatetimeType => + Seq( + StructField(VariantValueFieldName, BinaryType, nullable = true), + StructField(TypedValueFieldName, dataType, nullable = true) + ) + case _ => + // No other types have a corresponding shreddings schema. + throw QueryCompilationErrors.invalidVariantShreddingSchema(dataType) + } + + if (isTopLevel) { + StructType(StructField(MetadataFieldName, BinaryType, nullable = false) +: fields) + } else { + StructType(fields) + } + } + + /* + * Given a Spark schema that represents a valid shredding schema (e.g. constructed by + * SparkShreddingUtils.variantShreddingSchema), return the corresponding VariantSchema. 
+ */ + private def buildVariantSchema(schema: StructType, topLevel: Boolean): VariantSchema = { + var typedIdx = -1 + var variantIdx = -1 + var topLevelMetadataIdx = -1 + var scalarSchema: VariantSchema.ScalarType = null + var objectSchema: Array[VariantSchema.ObjectField] = null + var arraySchema: VariantSchema = null + + schema.fields.zipWithIndex.foreach { case (f, i) => + f.name match { + case TypedValueFieldName => + if (typedIdx != -1) { + throw QueryCompilationErrors.invalidVariantShreddingSchema(schema) + } + typedIdx = i + f.dataType match { + case StructType(fields) => + objectSchema = + new Array[VariantSchema.ObjectField](fields.length) + fields.zipWithIndex.foreach { case (field, fieldIdx) => + field.dataType match { + case s: StructType => + val fieldSchema = buildVariantSchema(s, topLevel = false) + objectSchema(fieldIdx) = new VariantSchema.ObjectField(field.name, fieldSchema) + case _ => throw QueryCompilationErrors.invalidVariantShreddingSchema(schema) + } + } + case ArrayType(elementType, _) => + elementType match { + case s: StructType => arraySchema = buildVariantSchema(s, topLevel = false) + case _ => throw QueryCompilationErrors.invalidVariantShreddingSchema(schema) + } + case t => scalarSchema = (t match { + case BooleanType => new VariantSchema.BooleanType + case ByteType => new VariantSchema.IntegralType(VariantSchema.IntegralSize.BYTE) + case ShortType => new VariantSchema.IntegralType(VariantSchema.IntegralSize.SHORT) + case IntegerType => new VariantSchema.IntegralType(VariantSchema.IntegralSize.INT) + case LongType => new VariantSchema.IntegralType(VariantSchema.IntegralSize.LONG) + case FloatType => new VariantSchema.FloatType + case DoubleType => new VariantSchema.DoubleType + case StringType => new VariantSchema.StringType + case BinaryType => new VariantSchema.BinaryType + case DateType => new VariantSchema.DateType + case TimestampType => new VariantSchema.TimestampType + case TimestampNTZType => new VariantSchema.TimestampNTZType + case d: DecimalType => new VariantSchema.DecimalType(d.precision, d.scale) + case _ => throw QueryCompilationErrors.invalidVariantShreddingSchema(schema) + }) + } + case VariantValueFieldName => + if (variantIdx != -1 || f.dataType != BinaryType) { + throw QueryCompilationErrors.invalidVariantShreddingSchema(schema) + } + variantIdx = i + case MetadataFieldName => + if (topLevelMetadataIdx != -1 || f.dataType != BinaryType) { + throw QueryCompilationErrors.invalidVariantShreddingSchema(schema) + } + topLevelMetadataIdx = i + case _ => throw QueryCompilationErrors.invalidVariantShreddingSchema(schema) + } + } + + if (topLevel != (topLevelMetadataIdx >= 0)) { + throw QueryCompilationErrors.invalidVariantShreddingSchema(schema) + } + new VariantSchema(typedIdx, variantIdx, topLevelMetadataIdx, schema.fields.length, + scalarSchema, objectSchema, arraySchema) + } + + class SparkShreddedResult(schema: VariantSchema) extends VariantShreddingWriter.ShreddedResult { + // Result is stored as an InternalRow. 
+ val row = new GenericInternalRow(schema.numFields) + + override def addArray(array: Array[VariantShreddingWriter.ShreddedResult]): Unit = { + val arrayResult = new GenericArrayData( + array.map(_.asInstanceOf[SparkShreddedResult].row)) + row.update(schema.typedIdx, arrayResult) + } + + override def addObject(values: Array[VariantShreddingWriter.ShreddedResult]): Unit = { + val innerRow = new GenericInternalRow(schema.objectSchema.size) + for (i <- 0 until values.length) { + innerRow.update(i, values(i).asInstanceOf[SparkShreddedResult].row) + } + row.update(schema.typedIdx, innerRow) + } + + override def addVariantValue(result: Array[Byte]): Unit = { + row.update(schema.variantIdx, result) + } + + override def addScalar(result: Any): Unit = { + // Convert to native spark value, if necessary. + val sparkValue = schema.scalarSchema match { + case _: VariantSchema.StringType => UTF8String.fromString(result.asInstanceOf[String]) + case _: VariantSchema.DecimalType => Decimal(result.asInstanceOf[java.math.BigDecimal]) + case _ => result + } + row.update(schema.typedIdx, sparkValue) + } + + override def addMetadata(result: Array[Byte]): Unit = { + row.update(schema.topLevelMetadataIdx, result) + } + } + + class SparkShreddedResultBuilder() extends VariantShreddingWriter.ShreddedResultBuilder { + override def createEmpty(schema: VariantSchema): VariantShreddingWriter.ShreddedResult = { + new SparkShreddedResult(schema) + } + + // Consider allowing this to be set via config? + override def allowNumericScaleChanges(): Boolean = true + } + + /** + * Converts an input variant into shredded components. Returns the shredded result. + */ + def castShredded(v: Variant, schema: VariantSchema): InternalRow = { + VariantShreddingWriter.castShredded(v, schema, new SparkShreddedResultBuilder()) + .asInstanceOf[SparkShreddedResult] + .row + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/VariantWriteShreddingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/VariantWriteShreddingSuite.scala new file mode 100644 index 0000000000000..ed66ddb1f0f44 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/VariantWriteShreddingSuite.scala @@ -0,0 +1,218 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.CatalystTypeConverters +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.variant.VariantExpressionEvalUtils +import org.apache.spark.sql.execution.datasources.parquet.SparkShreddingUtils +import org.apache.spark.sql.types._ +import org.apache.spark.types.variant.Variant +import org.apache.spark.unsafe.types.{UTF8String, VariantVal} + +class VariantWriteShreddingSuite extends SparkFunSuite with ExpressionEvalHelper { + + private def parseJson(input: String): VariantVal = + VariantExpressionEvalUtils.parseJson(UTF8String.fromString(input)) + + private def toVariant(input: Expression): VariantVal = { + Cast(input, VariantType, true).eval().asInstanceOf[VariantVal] + } + + private def untypedValue(input: String): Array[Byte] = { + val variantVal = parseJson(input) + variantVal.getValue + } + + private def untypedValue(input: VariantVal): Array[Byte] = input.getValue + + // Shreds variantVal with the requested schema, and verifies that the result is + // equal to `expected`. + private def testWithSchema(variantVal: VariantVal, + schema: DataType, expected: Row): Unit = { + val shreddingSchema = SparkShreddingUtils.variantShreddingSchema(schema) + val variant = new Variant(variantVal.getValue, variantVal.getMetadata) + val variantSchema = SparkShreddingUtils.buildVariantSchema(shreddingSchema) + val actual = SparkShreddingUtils.castShredded(variant, variantSchema) + + val catalystExpected = CatalystTypeConverters.convertToCatalyst(expected) + if (!checkResult(actual, catalystExpected, shreddingSchema, exprNullable = false)) { + fail(s"Incorrect evaluation of castShredded: " + + s"actual: $actual, " + + s"expected: $expected") + } + } + + // Parse the provided JSON into a Variant, shred it to the provided schema, and verify the result. + private def testWithSchema(input: String, dataType: DataType, expected: Row): Unit = { + val variantVal = parseJson(input) + testWithSchema(variantVal, dataType, expected) + } + + private val emptyMetadata: Array[Byte] = parseJson("null").getMetadata + + test("shredding as fixed numeric types") { + /* Cast integer to any wider numeric type. */ + testWithSchema("1", IntegerType, Row(emptyMetadata, null, 1)) + testWithSchema("1", LongType, Row(emptyMetadata, null, 1)) + testWithSchema("1", ShortType, Row(emptyMetadata, null, 1)) + testWithSchema("1", ByteType, Row(emptyMetadata, null, 1)) + + // Invalid casts + Seq(StringType, DecimalType(5, 5), TimestampType, DateType, BooleanType, DoubleType, FloatType, + BinaryType, ArrayType(IntegerType), + StructType.fromDDL("a int, b int")).foreach { t => + testWithSchema("1", t, Row(emptyMetadata, untypedValue("1"), null)) + } + + /* Test conversions between numeric types and scales. */ + testWithSchema("1", DecimalType(5, 2), Row(emptyMetadata, null, Decimal("1"))) + testWithSchema("1", DecimalType(38, 37), Row(emptyMetadata, null, Decimal("1"))) + // Decimals that are effectively storing integers can also be cast to integer. + testWithSchema("1.0", IntegerType, Row(emptyMetadata, null, 1)) + testWithSchema("1.0000000000000000000000000000000000000", IntegerType, + Row(emptyMetadata, null, 1)) + // Don't overflow the integer type when converting from decimal. 
+ testWithSchema("32767.0", ShortType, Row(emptyMetadata, null, 32767)) + testWithSchema("32768.0", ShortType, Row(emptyMetadata, untypedValue("32768.0"), null)) + // Don't overflow decimal type when converting from integer. + testWithSchema("99999", DecimalType(7, 2), Row(emptyMetadata, null, Decimal("99999.00"))) + testWithSchema("100000", DecimalType(7, 2), Row(emptyMetadata, untypedValue("100000"), null)) + // Allow scale to increase + testWithSchema("12.34", DecimalType(7, 4), Row(emptyMetadata, null, Decimal("12.3400"))) + // Allow scale to decrease if there are trailing zeros + testWithSchema("12.3400", DecimalType(4, 2), Row(emptyMetadata, null, Decimal("12.34"))) + testWithSchema("12.3410", DecimalType(4, 2), Row(emptyMetadata, untypedValue("12.3410"), null)) + + // The string 1 is not numeric + testWithSchema("\"1\"", IntegerType, Row(emptyMetadata, untypedValue("\"1\""), null)) + // Decimal would lose information. + testWithSchema("1.1", IntegerType, Row(emptyMetadata, untypedValue("1.1"), null)) + // Exponential notation is parsed as double, cannot be shredded to other numeric types. + testWithSchema("1e2", IntegerType, Row(emptyMetadata, untypedValue("1e2"), null)) + // Null is not an integer + testWithSchema("null", IntegerType, Row(emptyMetadata, untypedValue("null"), null)) + + // Overflow leads to storing as unshredded. + testWithSchema("32767", ShortType, Row(emptyMetadata, null, 32767)) + testWithSchema("32768", ShortType, Row(emptyMetadata, untypedValue("32768"), null)) + + testWithSchema("1e2", DoubleType, Row(emptyMetadata, null, 1e2)) + // We currently don't allow shredding double as float. + testWithSchema("1e2", FloatType, Row(emptyMetadata, untypedValue("1e2"), null)) + } + + test("shredding as other scalar types") { + // Test types that aren't produced by parseJson + val floatV = toVariant(Literal(1.2f, FloatType)) + testWithSchema(floatV, FloatType, Row(emptyMetadata, null, 1.2f)) + testWithSchema(floatV, DoubleType, Row(emptyMetadata, untypedValue(floatV), null)) + + val booleanV = toVariant(Literal(true, BooleanType)) + testWithSchema(booleanV, BooleanType, Row(emptyMetadata, null, true)) + testWithSchema(booleanV, StringType, Row(emptyMetadata, untypedValue(booleanV), null)) + + val binaryV = toVariant(Literal(Array[Byte](-1, -2), BinaryType)) + testWithSchema(binaryV, BinaryType, Row(emptyMetadata, null, Array[Byte](-1, -2))) + testWithSchema(binaryV, StringType, Row(emptyMetadata, untypedValue(binaryV), null)) + + val dateV = toVariant(Literal(0, DateType)) + testWithSchema(dateV, DateType, Row(emptyMetadata, null, 0)) + testWithSchema(dateV, TimestampType, Row(emptyMetadata, untypedValue(dateV), null)) + + val timestampV = toVariant(Literal(0L, TimestampType)) + testWithSchema(timestampV, TimestampType, Row(emptyMetadata, null, 0)) + testWithSchema(timestampV, TimestampNTZType, Row(emptyMetadata, untypedValue(timestampV), null)) + + val timestampNtzV = toVariant(Literal(0L, TimestampNTZType)) + testWithSchema(timestampNtzV, TimestampNTZType, Row(emptyMetadata, null, 0)) + testWithSchema(timestampNtzV, TimestampType, + Row(emptyMetadata, untypedValue(timestampNtzV), null)) + } + + test("shredding as object") { + val obj = parseJson("""{"a": 1, "b": "hello"}""") + // Can't be cast to scalar or array. 
+ Seq(IntegerType, LongType, ShortType, ByteType, StringType, DecimalType(5, 5), + TimestampType, DateType, BooleanType, DoubleType, FloatType, BinaryType, + ArrayType(IntegerType)).foreach { t => + testWithSchema(obj, t, Row(obj.getMetadata, untypedValue(obj), null)) + } + + // Happy path + testWithSchema(obj, StructType.fromDDL("a int, b string"), + Row(obj.getMetadata, null, Row(Row(null, 1), Row(null, "hello")))) + // Missing field. + testWithSchema(obj, StructType.fromDDL("a int, c string, b string"), + Row(obj.getMetadata, null, Row(Row(null, 1), Row(null, null), Row(null, "hello")))) + // "a" is not present in shredding schema. + testWithSchema(obj, StructType.fromDDL("b string, c string"), + Row(obj.getMetadata, untypedValue("""{"a": 1}"""), Row(Row(null, "hello"), Row(null, null)))) + // "b" is not present in shredding schema. This case is a bit trickier, because the ID + // will be 1, not 0, since we'll use the original metadata dictionary that contains a and b. + // So we need to edit the variant value produced by parseJson. + val residual = untypedValue("""{"b": "hello"}""") + // First byte is the type, second is number of fields, and the third is the + // dictionary ID of the first field. + residual(2) = 1 + testWithSchema(obj, StructType.fromDDL("a int, c string"), + Row(obj.getMetadata, residual, Row(Row(null, 1), Row(null, null)))) + // "a" is the wrong type. + testWithSchema(obj, StructType.fromDDL("a string, b string"), + Row(obj.getMetadata, null, Row(Row(untypedValue("1"), null), Row(null, "hello")))) + // Not an object + testWithSchema(obj, ArrayType(StructType.fromDDL("a int, b string")), + Row(obj.getMetadata, untypedValue(obj), null)) + } + + test("shredding as array") { + val arr = parseJson("""[{"a": 1, "b": "hello"}, 2, null, 4]""") + // Can't be cast to scalar or object. + Seq(IntegerType, LongType, ShortType, ByteType, StringType, DecimalType(5, 5), + TimestampType, DateType, BooleanType, DoubleType, FloatType, BinaryType, + StructType.fromDDL("a int, b string")).foreach { t => + testWithSchema(arr, t, Row(arr.getMetadata, untypedValue(arr), null)) + } + // First element is shredded + testWithSchema(arr, ArrayType(StructType.fromDDL("a int, b string")), + Row(arr.getMetadata, null, Array( + Row(null, Row(Row(null, 1), Row(null, "hello"))), + Row(untypedValue("2"), null), + Row(untypedValue("null"), null), + Row(untypedValue("4"), null) + ))) + // Second and fourth are shredded + testWithSchema(arr, ArrayType(LongType), + Row(arr.getMetadata, null, Array( + Row(untypedValue("""{"a": 1, "b": "hello"}"""), null), + Row(null, 2), + Row(untypedValue("null"), null), + Row(null, 4) + ))) + + // Fully shredded + testWithSchema("[1,2,3]", ArrayType(LongType), + Row(emptyMetadata, null, Array( + Row(null, 1), + Row(null, 2), + Row(null, 3) + ))) + } + +} From 1198117a82abce6c50629ce1a325b748f0667b39 Mon Sep 17 00:00:00 2001 From: Takuya Ueshin Date: Wed, 13 Nov 2024 09:45:39 +0800 Subject: [PATCH 12/39] [SPARK-50130][SQL][PYTHON][TEST] Add more test for outer reference ### What changes were proposed in this pull request? Adds more test for outer reference. ### Why are the changes needed? One test is missing. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Added the test. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48828 from ueshin/issues/SPARK-50130/test. 
Authored-by: Takuya Ueshin
Signed-off-by: Ruifeng Zheng
---
 python/pyspark/sql/tests/test_subquery.py    | 20 ++++++++++++++++++++
 .../spark/sql/DataFrameSubquerySuite.scala   | 15 +++++++++++++++
 2 files changed, 35 insertions(+)

diff --git a/python/pyspark/sql/tests/test_subquery.py b/python/pyspark/sql/tests/test_subquery.py
index 7d50d0959c215..f58ff6364aed7 100644
--- a/python/pyspark/sql/tests/test_subquery.py
+++ b/python/pyspark/sql/tests/test_subquery.py
@@ -470,6 +470,26 @@ def test_scalar_subquery_with_outer_reference_errors(self):
                 fragment="outer",
             )
 
+        with self.subTest("missing `outer()` for another outer"):
+            with self.assertRaises(AnalysisException) as pe:
+                self.spark.table("l").select(
+                    "a",
+                    (
+                        self.spark.table("r")
+                        .where(sf.col("b") == sf.col("a").outer())
+                        .select(sf.sum("d"))
+                        .scalar()
+                    ),
+                ).collect()
+
+            self.check_error(
+                exception=pe.exception,
+                errorClass="UNRESOLVED_COLUMN.WITH_SUGGESTION",
+                messageParameters={"objectName": "`b`", "proposal": "`c`, `d`"},
+                query_context_type=QueryContextType.DataFrame,
+                fragment="col",
+            )
+
 
 class SubqueryTests(SubqueryTestsMixin, ReusedSQLTestCase):
     pass
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSubquerySuite.scala
index fd31efb3054b1..5a065d7e73b1c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSubquerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSubquerySuite.scala
@@ -363,5 +363,20 @@ class DataFrameSubquerySuite extends QueryTest with SharedSparkSession {
       queryContext =
         Array(ExpectedContext(fragment = "outer", callSitePattern = getCurrentClassCallSitePattern))
     )
+
+    // Missing `outer()` for another outer
+    val exception3 = intercept[AnalysisException] {
+      spark.table("l").select(
+        $"a",
+        spark.table("r").where($"b" === $"a".outer()).select(sum($"d")).scalar()
+      ).collect()
+    }
+    checkError(
+      exception3,
+      condition = "UNRESOLVED_COLUMN.WITH_SUGGESTION",
+      parameters = Map("objectName" -> "`b`", "proposal" -> "`c`, `d`"),
+      queryContext =
+        Array(ExpectedContext(fragment = "$", callSitePattern = getCurrentClassCallSitePattern))
+    )
   }
 }

From 26330355836f5b2dad9b7bd4c72d9830c7ce6788 Mon Sep 17 00:00:00 2001
From: Paddy Xu
Date: Wed, 13 Nov 2024 10:51:02 +0900
Subject: [PATCH 13/39] [SPARK-49249][SPARK-49122] Artifact isolation in Spark Classic

### What changes were proposed in this pull request?

This PR makes the isolation feature introduced by the `SparkSession.addArtifact` API (added in https://github.com/apache/spark/pull/47631) work with Spark SQL.

Note that this PR does not enable isolation for the following two use cases:
- PySpark - Future work is needed to add an API to support adding isolated Python UDTFs.
- When Hive is used as the metastore - Hive UDFs are a huge blocker, because artifacts can be used outside a `SparkSession`, which lets resources escape our session scope.

### Why are the changes needed?

Because it didn't work before :)

### Does this PR introduce _any_ user-facing change?

Yes, the user can add a new artifact in the REPL and use it in the current REPL session.

### How was this patch tested?

Added a new test.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #48120 from xupefei/session-artifact-apply.
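For illustration (not part of the patch): a minimal sketch of the user-facing flow this change enables, based on the ReplSuite test added below. The artifact path is a placeholder, and `spark` is assumed to be an active classic SparkSession with artifact isolation enabled.

    import org.apache.spark.sql.functions.col
    import org.apache.spark.sql.types.DataTypes

    // Add a compiled UDF class as a session-scoped artifact, then register and use it.
    spark.addArtifact("/path/to/IntSumUdf.class")
    spark.udf.registerJava("intSum", "IntSumUdf", DataTypes.LongType)

    val r = spark.range(5)
      .withColumn("id2", col("id") + 1)
      .selectExpr("intSum(id, id2)")
      .collect()
    // r.map(_.getLong(0)) == Seq(1, 3, 5, 7, 9). A second, separate session cannot resolve
    // `intSum`, because the artifact and the registered UDF are isolated to this session.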
Authored-by: Paddy Xu Signed-off-by: Hyukjin Kwon --- .../apache/spark/sql/UDFRegistration.scala | 6 + .../scala/org/apache/spark/SparkFiles.scala | 8 +- python/pyspark/core/context.py | 2 + python/pyspark/sql/connect/session.py | 6 +- .../scala/org/apache/spark/repl/Main.scala | 4 + repl/src/test/resources/IntSumUdf.class | Bin 0 -> 1333 bytes repl/src/test/resources/IntSumUdf.scala | 22 +++ .../org/apache/spark/repl/ReplSuite.scala | 63 +++++++ .../spark/sql/api/UDFRegistration.scala | 17 ++ .../apache/spark/sql/internal/SQLConf.scala | 22 +++ .../connect/SimpleSparkConnectService.scala | 3 + .../connect/service/SparkConnectServer.scala | 7 +- .../org/apache/spark/sql/SparkSession.scala | 4 + .../apache/spark/sql/UDFRegistration.scala | 18 +- .../spark/sql/artifact/ArtifactManager.scala | 74 +++++--- .../spark/sql/execution/SQLExecution.scala | 173 +++++++++--------- .../execution/streaming/StreamExecution.scala | 4 +- .../internal/BaseSessionStateBuilder.scala | 2 +- .../sql/artifact/ArtifactManagerSuite.scala | 27 ++- .../sql/execution/command/DDLSuite.scala | 7 +- .../sql/hive/execution/HiveQuerySuite.scala | 8 + .../apache/spark/sql/hive/test/TestHive.scala | 6 +- .../sql/hive/test/TestHiveSingleton.scala | 7 + 23 files changed, 344 insertions(+), 146 deletions(-) create mode 100644 repl/src/test/resources/IntSumUdf.class create mode 100644 repl/src/test/resources/IntSumUdf.scala diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/UDFRegistration.scala index 3a84d43ceae3b..93d085a25c7b5 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/UDFRegistration.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/UDFRegistration.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql import org.apache.spark.sql.expressions.UserDefinedFunction import org.apache.spark.sql.internal.UdfToProtoUtils +import org.apache.spark.sql.types.DataType /** * Functions for registering user-defined functions. Use `SparkSession.udf` to access this: @@ -30,6 +31,11 @@ import org.apache.spark.sql.internal.UdfToProtoUtils * @since 3.5.0 */ class UDFRegistration(session: SparkSession) extends api.UDFRegistration { + override def registerJava(name: String, className: String, returnDataType: DataType): Unit = { + throw new UnsupportedOperationException( + "registerJava is currently not supported in Spark Connect.") + } + override protected def register( name: String, udf: UserDefinedFunction, diff --git a/core/src/main/scala/org/apache/spark/SparkFiles.scala b/core/src/main/scala/org/apache/spark/SparkFiles.scala index 44f4444a1fa8d..f4165c2fc6f28 100644 --- a/core/src/main/scala/org/apache/spark/SparkFiles.scala +++ b/core/src/main/scala/org/apache/spark/SparkFiles.scala @@ -27,8 +27,12 @@ object SparkFiles { /** * Get the absolute path of a file added through `SparkContext.addFile()`. */ - def get(filename: String): String = - new File(getRootDirectory(), filename).getAbsolutePath() + def get(filename: String): String = { + val jobArtifactUUID = JobArtifactSet + .getCurrentJobArtifactState.map(_.uuid).getOrElse("default") + val withUuid = if (jobArtifactUUID == "default") filename else s"$jobArtifactUUID/$filename" + new File(getRootDirectory(), withUuid).getAbsolutePath + } /** * Get the root directory that contains files added through `SparkContext.addFile()`. 
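For illustration (not part of the patch): the `SparkFiles.get` change above resolves files through a per-session subdirectory whenever a `JobArtifactState` is active. A standalone sketch of that lookup rule, where `rootDir` stands in for `SparkFiles.getRootDirectory()` and `activeUuid` for the current state's uuid (these names are assumptions for the sketch, not Spark APIs):

    import java.io.File

    object SparkFilesLookupSketch {
      // Files added by an isolated session live under <root>/<session-uuid>/;
      // the default session keeps the old <root>/<filename> layout.
      def resolve(rootDir: String, activeUuid: Option[String], filename: String): String = {
        val uuid = activeUuid.getOrElse("default")
        val withUuid = if (uuid == "default") filename else s"$uuid/$filename"
        new File(rootDir, withUuid).getAbsolutePath
      }

      def main(args: Array[String]): Unit = {
        println(resolve("/tmp/spark-files", None, "data.txt"))            // .../data.txt
        println(resolve("/tmp/spark-files", Some("abc-123"), "data.txt")) // .../abc-123/data.txt
      }
    }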
diff --git a/python/pyspark/core/context.py b/python/pyspark/core/context.py
index 63d41c11dafda..6ea793a118389 100644
--- a/python/pyspark/core/context.py
+++ b/python/pyspark/core/context.py
@@ -84,6 +84,8 @@
 DEFAULT_CONFIGS: Dict[str, Any] = {
     "spark.serializer.objectStreamReset": 100,
     "spark.rdd.compress": True,
+    # Disable artifact isolation in PySpark, or user-added .py files won't work
+    "spark.sql.artifact.isolation.enabled": "false",
 }
 
 T = TypeVar("T")
diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py
index e9984fae9ddba..83b0496a84274 100644
--- a/python/pyspark/sql/connect/session.py
+++ b/python/pyspark/sql/connect/session.py
@@ -1037,7 +1037,11 @@ def _start_connect_server(master: str, opts: Dict[str, Any]) -> None:
         os.environ["SPARK_LOCAL_CONNECT"] = "1"
 
         # Configurations to be set if unset.
-        default_conf = {"spark.plugins": "org.apache.spark.sql.connect.SparkConnectPlugin"}
+        default_conf = {
+            "spark.plugins": "org.apache.spark.sql.connect.SparkConnectPlugin",
+            "spark.sql.artifact.isolation.enabled": "true",
+            "spark.sql.artifact.isolation.always.apply.classloader": "true",
+        }
 
         if "SPARK_TESTING" in os.environ:
             # For testing, we use 0 to use an ephemeral port to allow parallel testing.
diff --git a/repl/src/main/scala/org/apache/spark/repl/Main.scala b/repl/src/main/scala/org/apache/spark/repl/Main.scala
index 7b126c357271b..4d3465b320391 100644
--- a/repl/src/main/scala/org/apache/spark/repl/Main.scala
+++ b/repl/src/main/scala/org/apache/spark/repl/Main.scala
@@ -25,6 +25,7 @@ import scala.tools.nsc.GenericRunnerSettings
 
 import org.apache.spark._
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION
 import org.apache.spark.util.Utils
@@ -95,6 +96,9 @@ object Main extends Logging {
     // initialization in certain cases, there's an initialization order issue that prevents
     // this from being set after SparkContext is instantiated.
     conf.set("spark.repl.class.outputDir", outputDir.getAbsolutePath())
+    // Disable isolation for the REPL, to avoid having in-line classes stored in an isolated
+    // directory, which would prevent the REPL classloader from finding them.
+ conf.set(SQLConf.ARTIFACTS_SESSION_ISOLATION_ENABLED, false) if (execUri != null) { conf.set("spark.executor.uri", execUri) } diff --git a/repl/src/test/resources/IntSumUdf.class b/repl/src/test/resources/IntSumUdf.class new file mode 100644 index 0000000000000000000000000000000000000000..75a41446cfca1f93570ab90a74d80b51e0fb59e4 GIT binary patch literal 1333 zcma)5U2oG!6usBZIJJ`|cG>`GyR;4klD5>eD$o)Z$qA$+N(&)r5uef|aZ{JXCXNGT zrTs}ryDxzDVSm8>sMTJ3LsijA@Gx^{&fI&>o#XL8|Niwi0Bpcb1i9k)jZuHEa}WhY zaB4^VJvVyX=^Bq)M^@V})LgG?@OQj@Xc@zI+?mFBbc@_z>)<@qOBBmw$ zb=~ABVp7l78J2KI&ubEv^gX$vRp$!Lx3^cyJ2R3HV_RnK#hrnD;BTO*L{LYeWU}Ne ztLsHmniMsDap%=WOuQwV=F7DYrEA)tC<@VerTMI%DQK~Ww%2*2Cd5iz*B>Y`R?nxn zxi)AlX}n#=NVnX%Uc|{6F`bz(GIx7vxxT)LkrtQcqN}98#WI84*Bc7^sAaLZ{BW`= zWU-P;i0r9mR#56D;@hjwnOI>d6PugvV` zt60alU<$v>I=;!sv8@zlOjTg@X|g^-hV{s%T(z37A_)FkFs;*pe`pU8Vzm>`$xWn` z&D_s?{`}b_1o3#0=xHYYnem4kIlkSujG}vN4WEvjCT*f{frLj+%RO$oHP`7Pq|Xt} zSL1LAG74OTnK;B@QUL+x=#8JMo61#BcM^~l;&2(}gQ^9D)Ol@6_iV@Z?^CC_{C*s+ zz@h>Su97Duw`0*Wtl5sWJLhFI{!pbfE>*dF^ zV3hz>HHm&|OutLyLAkoL{1s+bXnclWzJmS-VW*|0Pf4jKQYuVCHk4ASnJiU|rKTw* zO#=hhX%22;Pz<_eOJCsHpXc(&|4$h6GguC@L|6&q1l*t)qa2xrGNk1y0aHXOhB~2b L;pIq-K=SQBYqk#E literal 0 HcmV?d00001 diff --git a/repl/src/test/resources/IntSumUdf.scala b/repl/src/test/resources/IntSumUdf.scala new file mode 100644 index 0000000000000..9678caaed5db5 --- /dev/null +++ b/repl/src/test/resources/IntSumUdf.scala @@ -0,0 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +import org.apache.spark.sql.api.java.UDF2 + +class IntSumUdf extends UDF2[Long, Long, Long] { + override def call(t1: Long, t2: Long): Long = t1 + t2 +} diff --git a/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala index 1a7be083d2d92..327ef3d074207 100644 --- a/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala +++ b/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala @@ -396,4 +396,67 @@ class ReplSuite extends SparkFunSuite { Main.sparkContext.stop() System.clearProperty("spark.driver.port") } + + test("register UDF via SparkSession.addArtifact") { + val artifactPath = new File("src/test/resources").toPath + val intSumUdfPath = artifactPath.resolve("IntSumUdf.class") + val output = runInterpreterInPasteMode("local", + s""" + |import org.apache.spark.sql.api.java.UDF2 + |import org.apache.spark.sql.types.DataTypes + | + |spark.addArtifact("${intSumUdfPath.toString}") + | + |spark.udf.registerJava("intSum", "IntSumUdf", DataTypes.LongType) + | + |val r = spark.range(5) + | .withColumn("id2", col("id") + 1) + | .selectExpr("intSum(id, id2)") + | .collect() + |assert(r.map(_.getLong(0)).toSeq == Seq(1, 3, 5, 7, 9)) + | + """.stripMargin) + assertContains("Array([1], [3], [5], [7], [9])", output) + assertDoesNotContain("error:", output) + assertDoesNotContain("Exception", output) + assertDoesNotContain("assertion failed", output) + + // The UDF should not work in a new REPL session. + val anotherOutput = runInterpreterInPasteMode("local", + s""" + |val r = spark.range(5) + | .withColumn("id2", col("id") + 1) + | .selectExpr("intSum(id, id2)") + | .collect() + | + """.stripMargin) + assertContains( + "[UNRESOLVED_ROUTINE] Cannot resolve routine `intSum` on search path", + anotherOutput) + } + + test("register a class via SparkSession.addArtifact") { + val artifactPath = new File("src/test/resources").toPath + val intSumUdfPath = artifactPath.resolve("IntSumUdf.class") + val output = runInterpreterInPasteMode("local", + s""" + |import org.apache.spark.sql.functions.udf + | + |spark.addArtifact("${intSumUdfPath.toString}") + | + |val intSumUdf = udf((x: Long, y: Long) => new IntSumUdf().call(x, y)) + |spark.udf.register("intSum", intSumUdf) + | + |val r = spark.range(5) + | .withColumn("id2", col("id") + 1) + | .selectExpr("intSum(id, id2)") + | .collect() + |assert(r.map(_.getLong(0)).toSeq == Seq(1, 3, 5, 7, 9)) + | + """.stripMargin) + assertContains("Array([1], [3], [5], [7], [9])", output) + assertDoesNotContain("error:", output) + assertDoesNotContain("Exception", output) + assertDoesNotContain("assertion failed", output) + } } diff --git a/sql/api/src/main/scala/org/apache/spark/sql/api/UDFRegistration.scala b/sql/api/src/main/scala/org/apache/spark/sql/api/UDFRegistration.scala index c11e266827ff9..a8e8f5c5f8556 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/api/UDFRegistration.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/api/UDFRegistration.scala @@ -35,6 +35,23 @@ import org.apache.spark.sql.types.DataType */ abstract class UDFRegistration { + /** + * Register a Java UDF class using it's class name. The class must implement one of the UDF + * interfaces in the [[org.apache.spark.sql.api.java]] package, and discoverable by the current + * session's class loader. + * + * @param name + * Name of the UDF. + * @param className + * Fully qualified class name of the UDF. + * @param returnDataType + * Return type of UDF. If it is `null`, Spark would try to infer via reflection. 
+ * @note + * this method is currently not supported in Spark Connect. + * @since 4.0.0 + */ + def registerJava(name: String, className: String, returnDataType: DataType): Unit + /** * Registers a user-defined function (UDF), for a UDF that's already defined using the Dataset * API (i.e. of type UserDefinedFunction). To change a UDF to nondeterministic, call the API diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index d17ab656fe6b6..eac89212b9da8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -3957,6 +3957,28 @@ object SQLConf { .intConf .createWithDefault(20) + val ARTIFACTS_SESSION_ISOLATION_ENABLED = + buildConf("spark.sql.artifact.isolation.enabled") + .internal() + .doc("When enabled for a Spark Session, artifacts (such as JARs, files, archives) added to " + + "this session are isolated from other sessions within the same Spark instance. When " + + "disabled for a session, artifacts added to this session are visible to other sessions " + + "that have this config disabled. This config can only be set during the creation of a " + + "Spark Session and will have no effect when changed in the middle of session usage.") + .version("4.0.0") + .booleanConf + .createWithDefault(true) + + val ARTIFACTS_SESSION_ISOLATION_ALWAYS_APPLY_CLASSLOADER = + buildConf("spark.sql.artifact.isolation.always.apply.classloader") + .internal() + .doc("When enabled, the classloader holding per-session artifacts will always be applied " + + "during SQL executions (useful for Spark Connect). When disabled, the classloader will " + + "be applied only when any artifact is added to the session.") + .version("4.0.0") + .booleanConf + .createWithDefault(false) + val FAST_HASH_AGGREGATE_MAX_ROWS_CAPACITY_BIT = buildConf("spark.sql.codegen.aggregate.fastHashMap.capacityBit") .internal() diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/SimpleSparkConnectService.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/SimpleSparkConnectService.scala index 1b6bdd8cd9393..8061e913dc0da 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/SimpleSparkConnectService.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/SimpleSparkConnectService.scala @@ -25,6 +25,7 @@ import scala.sys.exit import org.apache.spark.SparkConf import org.apache.spark.sql.SparkSession import org.apache.spark.sql.connect.service.SparkConnectService +import org.apache.spark.sql.internal.SQLConf /** * A simple main class method to start the spark connect server as a service for client tests @@ -40,6 +41,8 @@ private[sql] object SimpleSparkConnectService { def main(args: Array[String]): Unit = { val conf = new SparkConf() .set("spark.plugins", "org.apache.spark.sql.connect.SparkConnectPlugin") + .set(SQLConf.ARTIFACTS_SESSION_ISOLATION_ENABLED, true) + .set(SQLConf.ARTIFACTS_SESSION_ISOLATION_ALWAYS_APPLY_CLASSLOADER, true) val sparkSession = SparkSession.builder().config(conf).getOrCreate() val sparkContext = sparkSession.sparkContext // init spark context // scalastyle:off println diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectServer.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectServer.scala index 4f05ea927e12b..b2c4d1abb17b4 100644 --- 
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectServer.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectServer.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.connect.service import org.apache.spark.internal.{Logging, MDC} import org.apache.spark.internal.LogKeys.{HOST, PORT} import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.internal.SQLConf /** * The Spark Connect server @@ -28,7 +29,11 @@ object SparkConnectServer extends Logging { def main(args: Array[String]): Unit = { // Set the active Spark Session, and starts SparkEnv instance (via Spark Context) logInfo("Starting Spark session.") - val session = SparkSession.builder().getOrCreate() + val session = SparkSession + .builder() + .config(SQLConf.ARTIFACTS_SESSION_ISOLATION_ENABLED.key, true) + .config(SQLConf.ARTIFACTS_SESSION_ISOLATION_ALWAYS_APPLY_CLASSLOADER.key, true) + .getOrCreate() try { try { SparkConnectService.start(session.sparkContext) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index 3af4a26cf1876..afc0a2d7df604 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -789,7 +789,11 @@ object SparkSession extends api.BaseSparkSessionCompanion with Logging { /** @inheritdoc */ override def enableHiveSupport(): this.type = synchronized { if (hiveClassesArePresent) { + // TODO(SPARK-50244): We now isolate artifacts added by the `ADD JAR` command. This will + // break an existing Hive use case (one session adds JARs and another session uses them). + // We need to decide whether/how to enable isolation for Hive. super.enableHiveSupport() + .config(SQLConf.ARTIFACTS_SESSION_ISOLATION_ENABLED.key, false) } else { throw new IllegalArgumentException( "Unable to instantiate SparkSession with Hive support because " + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala index 2724399a1a84c..6715673cf3d1c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala @@ -32,7 +32,6 @@ import org.apache.spark.sql.execution.python.UserDefinedPythonFunction import org.apache.spark.sql.expressions.{SparkUserDefinedFunction, UserDefinedAggregateFunction, UserDefinedAggregator, UserDefinedFunction} import org.apache.spark.sql.internal.UserDefinedFunctionUtils.toScalaUDF import org.apache.spark.sql.types.DataType -import org.apache.spark.util.Utils /** * Functions for registering user-defined functions. 
Use `SparkSession.udf` to access this: @@ -44,7 +43,7 @@ import org.apache.spark.util.Utils * @since 1.3.0 */ @Stable -class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) +class UDFRegistration private[sql] (session: SparkSession, functionRegistry: FunctionRegistry) extends api.UDFRegistration with Logging { protected[sql] def registerPython(name: String, udf: UserDefinedPythonFunction): Unit = { @@ -121,7 +120,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) */ private[sql] def registerJavaUDAF(name: String, className: String): Unit = { try { - val clazz = Utils.classForName[AnyRef](className) + val clazz = session.artifactManager.classloader.loadClass(className) if (!classOf[UserDefinedAggregateFunction].isAssignableFrom(clazz)) { throw QueryCompilationErrors .classDoesNotImplementUserDefinedAggregateFunctionError(className) @@ -137,17 +136,10 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) } // scalastyle:off line.size.limit - /** - * Register a Java UDF class using reflection, for use from pyspark - * - * @param name udf name - * @param className fully qualified class name of udf - * @param returnDataType return type of udf. If it is null, spark would try to infer - * via reflection. - */ - private[sql] def registerJava(name: String, className: String, returnDataType: DataType): Unit = { + + override def registerJava(name: String, className: String, returnDataType: DataType): Unit = { try { - val clazz = Utils.classForName[AnyRef](className) + val clazz = session.artifactManager.classloader.loadClass(className) val udfInterfaces = clazz.getGenericInterfaces .filter(_.isInstanceOf[ParameterizedType]) .map(_.asInstanceOf[ParameterizedType]) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/artifact/ArtifactManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/artifact/ArtifactManager.scala index b81c369f7e9c6..d362c5bef878e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/artifact/ArtifactManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/artifact/ArtifactManager.scala @@ -22,6 +22,7 @@ import java.net.{URI, URL, URLClassLoader} import java.nio.ByteBuffer import java.nio.file.{CopyOption, Files, Path, Paths, StandardCopyOption} import java.util.concurrent.CopyOnWriteArrayList +import java.util.concurrent.atomic.AtomicBoolean import scala.jdk.CollectionConverters._ import scala.reflect.ClassTag @@ -68,18 +69,43 @@ class ArtifactManager(session: SparkSession) extends Logging { s"$artifactRootURI${File.separator}${session.sessionUUID}") // The base directory/URI where all class file artifacts are stored for this `sessionUUID`. 
- protected[artifact] val (classDir, classURI): (Path, String) = + protected[artifact] val (classDir, replClassURI): (Path, String) = (ArtifactUtils.concatenatePaths(artifactPath, "classes"), s"$artifactURI${File.separator}classes${File.separator}") - protected[artifact] val state: JobArtifactState = - JobArtifactState(session.sessionUUID, Option(classURI)) + private lazy val alwaysApplyClassLoader = + session.conf.get(SQLConf.ARTIFACTS_SESSION_ISOLATION_ALWAYS_APPLY_CLASSLOADER.key).toBoolean - def withResources[T](f: => T): T = { - Utils.withContextClassLoader(classloader) { - JobArtifactSet.withActiveJobArtifactState(state) { + private lazy val sessionIsolated = + session.conf.get(SQLConf.ARTIFACTS_SESSION_ISOLATION_ENABLED.key).toBoolean + + protected[sql] lazy val state: JobArtifactState = + if (sessionIsolated) JobArtifactState(session.sessionUUID, Some(replClassURI)) else null + + /** + * Whether any artifact has been added to this artifact manager. We use this to determine whether + * we should apply the classloader to the session, see `withClassLoaderIfNeeded`. + */ + protected val sessionArtifactAdded = new AtomicBoolean(false) + + private def withClassLoaderIfNeeded[T](f: => T): T = { + val log = s" classloader for session ${session.sessionUUID} because " + + s"alwaysApplyClassLoader=$alwaysApplyClassLoader, " + + s"sessionArtifactAdded=${sessionArtifactAdded.get()}." + if (alwaysApplyClassLoader || sessionArtifactAdded.get()) { + logDebug(s"Applying $log") + Utils.withContextClassLoader(classloader) { f } + } else { + logDebug(s"Not applying $log") + f + } + } + + def withResources[T](f: => T): T = withClassLoaderIfNeeded { + JobArtifactSet.withActiveJobArtifactState(state) { + f } } @@ -176,6 +202,7 @@ class ArtifactManager(session: SparkSession) extends Logging { target, allowOverwrite = true, deleteSource = deleteStagedFile) + sessionArtifactAdded.set(true) } else { val target = ArtifactUtils.concatenatePaths(artifactPath, normalizedRemoteRelativePath) // Disallow overwriting with modified version @@ -199,6 +226,7 @@ class ArtifactManager(session: SparkSession) extends Logging { sparkContextRelativePaths.add( (SparkContextResourceType.JAR, normalizedRemoteRelativePath, fragment)) jarsList.add(normalizedRemoteRelativePath) + sessionArtifactAdded.set(true) } else if (normalizedRemoteRelativePath.startsWith(s"pyfiles${File.separator}")) { session.sparkContext.addFile(uri) sparkContextRelativePaths.add( @@ -258,9 +286,10 @@ class ArtifactManager(session: SparkSession) extends Logging { * Returns a [[ClassLoader]] for session-specific jar/class file resources. */ def classloader: ClassLoader = { - val urls = getAddedJars :+ classDir.toUri.toURL + val urls = (getAddedJars :+ classDir.toUri.toURL).toArray val prefixes = SparkEnv.get.conf.get(CONNECT_SCALA_UDF_STUB_PREFIXES) val userClasspathFirst = SparkEnv.get.conf.get(EXECUTOR_USER_CLASS_PATH_FIRST) + val fallbackClassLoader = session.sharedState.jarClassLoader val loader = if (prefixes.nonEmpty) { // Two things you need to know about classloader for all of this to make sense: // 1. A classloader needs to be able to fully define a class. @@ -274,21 +303,16 @@ class ArtifactManager(session: SparkSession) extends Logging { // it delegates to. 
if (userClasspathFirst) { // USER -> SYSTEM -> STUB - new ChildFirstURLClassLoader( - urls.toArray, - StubClassLoader(Utils.getContextOrSparkClassLoader, prefixes)) + new ChildFirstURLClassLoader(urls, StubClassLoader(fallbackClassLoader, prefixes)) } else { // SYSTEM -> USER -> STUB - new ChildFirstURLClassLoader( - urls.toArray, - StubClassLoader(null, prefixes), - Utils.getContextOrSparkClassLoader) + new ChildFirstURLClassLoader(urls, StubClassLoader(null, prefixes), fallbackClassLoader) } } else { if (userClasspathFirst) { - new ChildFirstURLClassLoader(urls.toArray, Utils.getContextOrSparkClassLoader) + new ChildFirstURLClassLoader(urls, fallbackClassLoader) } else { - new URLClassLoader(urls.toArray, Utils.getContextOrSparkClassLoader) + new URLClassLoader(urls, fallbackClassLoader) } } @@ -347,14 +371,16 @@ class ArtifactManager(session: SparkSession) extends Logging { // Clean up added files val fileserver = SparkEnv.get.rpcEnv.fileServer val sparkContext = session.sparkContext - val shouldUpdateEnv = sparkContext.addedFiles.contains(state.uuid) || - sparkContext.addedArchives.contains(state.uuid) || - sparkContext.addedJars.contains(state.uuid) - if (shouldUpdateEnv) { - sparkContext.addedFiles.remove(state.uuid).foreach(_.keys.foreach(fileserver.removeFile)) - sparkContext.addedArchives.remove(state.uuid).foreach(_.keys.foreach(fileserver.removeFile)) - sparkContext.addedJars.remove(state.uuid).foreach(_.keys.foreach(fileserver.removeJar)) - sparkContext.postEnvironmentUpdate() + if (state != null) { + val shouldUpdateEnv = sparkContext.addedFiles.contains(state.uuid) || + sparkContext.addedArchives.contains(state.uuid) || + sparkContext.addedJars.contains(state.uuid) + if (shouldUpdateEnv) { + sparkContext.addedFiles.remove(state.uuid).foreach(_.keys.foreach(fileserver.removeFile)) + sparkContext.addedArchives.remove(state.uuid).foreach(_.keys.foreach(fileserver.removeFile)) + sparkContext.addedJars.remove(state.uuid).foreach(_.keys.foreach(fileserver.removeJar)) + sparkContext.postEnvironmentUpdate() + } } // Clean up cached relations diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala index 5db14a8662138..e805aabe013cf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala @@ -120,93 +120,97 @@ object SQLExecution extends Logging { val redactedConfigs = sparkSession.sessionState.conf.redactOptions(modifiedConfigs) withSQLConfPropagated(sparkSession) { - withSessionTagsApplied(sparkSession) { - var ex: Option[Throwable] = None - var isExecutedPlanAvailable = false - val startTime = System.nanoTime() - val startEvent = SparkListenerSQLExecutionStart( - executionId = executionId, - rootExecutionId = Some(rootExecutionId), - description = desc, - details = callSite.longForm, - physicalPlanDescription = "", - sparkPlanInfo = SparkPlanInfo.EMPTY, - time = System.currentTimeMillis(), - modifiedConfigs = redactedConfigs, - jobTags = sc.getJobTags(), - jobGroupId = Option(sc.getLocalProperty(SparkContext.SPARK_JOB_GROUP_ID)) - ) - try { - body match { - case Left(e) => - sc.listenerBus.post(startEvent) + sparkSession.artifactManager.withResources { + withSessionTagsApplied(sparkSession) { + var ex: Option[Throwable] = None + var isExecutedPlanAvailable = false + val startTime = System.nanoTime() + val startEvent = SparkListenerSQLExecutionStart( + executionId = 
executionId, + rootExecutionId = Some(rootExecutionId), + description = desc, + details = callSite.longForm, + physicalPlanDescription = "", + sparkPlanInfo = SparkPlanInfo.EMPTY, + time = System.currentTimeMillis(), + modifiedConfigs = redactedConfigs, + jobTags = sc.getJobTags(), + jobGroupId = Option(sc.getLocalProperty(SparkContext.SPARK_JOB_GROUP_ID)) + ) + try { + body match { + case Left(e) => + sc.listenerBus.post(startEvent) + throw e + case Right(f) => + val planDescriptionMode = + ExplainMode.fromString(sparkSession.sessionState.conf.uiExplainMode) + val planDesc = queryExecution.explainString(planDescriptionMode) + val planInfo = try { + SparkPlanInfo.fromSparkPlan(queryExecution.executedPlan) + } catch { + case NonFatal(e) => + logDebug("Failed to generate SparkPlanInfo", e) + // If the queryExecution already failed before this, we are not able to + // generate the the plan info, so we use and empty graphviz node to make the + // UI happy + SparkPlanInfo.EMPTY + } + sc.listenerBus.post( + startEvent.copy(physicalPlanDescription = planDesc, sparkPlanInfo = planInfo)) + isExecutedPlanAvailable = true + f() + } + } catch { + case e: Throwable => + ex = Some(e) throw e - case Right(f) => - val planDescriptionMode = - ExplainMode.fromString(sparkSession.sessionState.conf.uiExplainMode) - val planDesc = queryExecution.explainString(planDescriptionMode) - val planInfo = try { - SparkPlanInfo.fromSparkPlan(queryExecution.executedPlan) - } catch { - case NonFatal(e) => - logDebug("Failed to generate SparkPlanInfo", e) - // If the queryExecution already failed before this, we are not able to generate - // the the plan info, so we use and empty graphviz node to make the UI happy - SparkPlanInfo.EMPTY - } - sc.listenerBus.post( - startEvent.copy(physicalPlanDescription = planDesc, sparkPlanInfo = planInfo)) - isExecutedPlanAvailable = true - f() - } - } catch { - case e: Throwable => - ex = Some(e) - throw e - } finally { - val endTime = System.nanoTime() - val errorMessage = ex.map { - case e: SparkThrowable => - SparkThrowableHelper.getMessage(e, ErrorMessageFormat.PRETTY) - case e => - Utils.exceptionString(e) - } - if (queryExecution.shuffleCleanupMode != DoNotCleanup - && isExecutedPlanAvailable) { - val shuffleIds = queryExecution.executedPlan match { - case ae: AdaptiveSparkPlanExec => - ae.context.shuffleIds.asScala.keys - case _ => - Iterable.empty + } finally { + val endTime = System.nanoTime() + val errorMessage = ex.map { + case e: SparkThrowable => + SparkThrowableHelper.getMessage(e, ErrorMessageFormat.PRETTY) + case e => + Utils.exceptionString(e) } - shuffleIds.foreach { shuffleId => - queryExecution.shuffleCleanupMode match { - case RemoveShuffleFiles => - // Same as what we do in ContextCleaner.doCleanupShuffle, but do not unregister - // the shuffle on MapOutputTracker, so that stage retries would be triggered. - // Set blocking to Utils.isTesting to deflake unit tests. 
- sc.shuffleDriverComponents.removeShuffle(shuffleId, Utils.isTesting) - case SkipMigration => - SparkEnv.get.blockManager.migratableResolver.addShuffleToSkip(shuffleId) - case _ => // this should not happen + if (queryExecution.shuffleCleanupMode != DoNotCleanup + && isExecutedPlanAvailable) { + val shuffleIds = queryExecution.executedPlan match { + case ae: AdaptiveSparkPlanExec => + ae.context.shuffleIds.asScala.keys + case _ => + Iterable.empty + } + shuffleIds.foreach { shuffleId => + queryExecution.shuffleCleanupMode match { + case RemoveShuffleFiles => + // Same as what we do in ContextCleaner.doCleanupShuffle, but do not + // unregister the shuffle on MapOutputTracker, so that stage retries would be + // triggered. + // Set blocking to Utils.isTesting to deflake unit tests. + sc.shuffleDriverComponents.removeShuffle(shuffleId, Utils.isTesting) + case SkipMigration => + SparkEnv.get.blockManager.migratableResolver.addShuffleToSkip(shuffleId) + case _ => // this should not happen + } } } + val event = SparkListenerSQLExecutionEnd( + executionId, + System.currentTimeMillis(), + // Use empty string to indicate no error, as None may mean events generated by old + // versions of Spark. + errorMessage.orElse(Some(""))) + // Currently only `Dataset.withAction` and `DataFrameWriter.runCommand` specify the + // `name` parameter. The `ExecutionListenerManager` only watches SQL executions with + // name. We can specify the execution name in more places in the future, so that + // `QueryExecutionListener` can track more cases. + event.executionName = name + event.duration = endTime - startTime + event.qe = queryExecution + event.executionFailure = ex + sc.listenerBus.post(event) } - val event = SparkListenerSQLExecutionEnd( - executionId, - System.currentTimeMillis(), - // Use empty string to indicate no error, as None may mean events generated by old - // versions of Spark. - errorMessage.orElse(Some(""))) - // Currently only `Dataset.withAction` and `DataFrameWriter.runCommand` specify the - // `name` parameter. The `ExecutionListenerManager` only watches SQL executions with - // name. We can specify the execution name in more places in the future, so that - // `QueryExecutionListener` can track more cases. - event.executionName = name - event.duration = endTime - startTime - event.qe = queryExecution - event.executionFailure = ex - sc.listenerBus.post(event) } } } @@ -301,7 +305,10 @@ object SQLExecution extends Logging { val activeSession = sparkSession val sc = sparkSession.sparkContext val localProps = Utils.cloneProperties(sc.getLocalProperties) - val artifactState = JobArtifactSet.getCurrentJobArtifactState.orNull + // `getCurrentJobArtifactState` will return a stat only in Spark Connect mode. In non-Connect + // mode, we default back to the resources of the current Spark session. 
+ val artifactState = JobArtifactSet.getCurrentJobArtifactState.getOrElse( + activeSession.artifactManager.state) exec.submit(() => JobArtifactSet.withActiveJobArtifactState(artifactState) { val originalSession = SparkSession.getActiveSession val originalLocalProps = sc.getLocalProperties diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index d8f32a2cb9225..bd501c9357234 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -223,9 +223,7 @@ abstract class StreamExecution( // To fix call site like "run at :0", we bridge the call site from the caller // thread to this micro batch thread sparkSession.sparkContext.setCallSite(callSite) - JobArtifactSet.withActiveJobArtifactState(jobArtifactState) { - runStream() - } + runStream() } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala index f22d4fe326689..59a873ef982fe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala @@ -181,7 +181,7 @@ abstract class BaseSessionStateBuilder( * Note 1: The user-defined functions must be deterministic. * Note 2: This depends on the `functionRegistry` field. */ - protected def udfRegistration: UDFRegistration = new UDFRegistration(functionRegistry) + protected def udfRegistration: UDFRegistration = new UDFRegistration(session, functionRegistry) protected def udtfRegistration: UDTFRegistration = new UDTFRegistration(tableFunctionRegistry) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/artifact/ArtifactManagerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/artifact/ArtifactManagerSuite.scala index e929a6b5303a5..e935af8b8bf8c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/artifact/ArtifactManagerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/artifact/ArtifactManagerSuite.scala @@ -24,8 +24,8 @@ import org.apache.commons.io.FileUtils import org.apache.spark.{SparkConf, SparkException} import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.api.java.UDF2 import org.apache.spark.sql.functions.col +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types.DataTypes import org.apache.spark.storage.CacheId @@ -36,6 +36,8 @@ class ArtifactManagerSuite extends SharedSparkSession { override protected def sparkConf: SparkConf = { val conf = super.sparkConf conf.set("spark.sql.artifact.copyFromLocalToFs.allowDestLocal", "true") + conf.set(SQLConf.ARTIFACTS_SESSION_ISOLATION_ENABLED, true) + conf.set(SQLConf.ARTIFACTS_SESSION_ISOLATION_ALWAYS_APPLY_CLASSLOADER, true) } private val artifactPath = new File("src/test/resources/artifact-tests").toPath @@ -331,24 +333,17 @@ class ArtifactManagerSuite extends SharedSparkSession { } } - test("Add UDF as artifact") { + test("Added artifact can be loaded by the current SparkSession") { val buffer = Files.readAllBytes(artifactPath.resolve("IntSumUdf.class")) spark.addArtifact(buffer, "IntSumUdf.class") - val instance = artifactManager.classloader - .loadClass("IntSumUdf") - .getDeclaredConstructor() - 
.newInstance() - .asInstanceOf[UDF2[Long, Long, Long]] - spark.udf.register("intSum", instance, DataTypes.LongType) - - artifactManager.withResources { - val r = spark.range(5) - .withColumn("id2", col("id") + 1) - .selectExpr("intSum(id, id2)") - .collect() - assert(r.map(_.getLong(0)).toSeq == Seq(1, 3, 5, 7, 9)) - } + spark.udf.registerJava("intSum", "IntSumUdf", DataTypes.LongType) + + val r = spark.range(5) + .withColumn("id2", col("id") + 1) + .selectExpr("intSum(id, id2)") + .collect() + assert(r.map(_.getLong(0)).toSeq == Seq(1, 3, 5, 7, 9)) } private def testAddArtifactToLocalSession( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index fec7183bc75e6..32a63f5c61976 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -2141,7 +2141,12 @@ abstract class DDLSuite extends QueryTest with DDLSuiteBase { root = Utils.createTempDir().getCanonicalPath, namePrefix = "addDirectory") val testFile = File.createTempFile("testFile", "1", directoryToAdd) spark.sql(s"ADD FILE $directoryToAdd") - assert(new File(SparkFiles.get(s"${directoryToAdd.getName}/${testFile.getName}")).exists()) + // TODO(SPARK-50244): ADD JAR is inside `sql()` thus isolated. This will break an existing Hive + // use case (one session adds JARs and another session uses them). After we sort out the Hive + // isolation issue we will decide if the next assert should be wrapped inside `withResources`. + spark.artifactManager.withResources { + assert(new File(SparkFiles.get(s"${directoryToAdd.getName}/${testFile.getName}")).exists()) + } } test(s"Add a directory when ${SQLConf.LEGACY_ADD_SINGLE_FILE_IN_ADD_FILE.key} set to true") { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index 42fc50e5b163b..c41370c96241a 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -70,6 +70,14 @@ class HiveQuerySuite extends HiveComparisonTest with SQLTestUtils with BeforeAnd } } + override def afterEach(): Unit = { + try { + spark.artifactManager.cleanUpResources() + } finally { + super.afterEach() + } + } + private def assertUnsupportedFeature( body: => Unit, operation: String, diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHive.scala index 9611a37ef0d06..247a1c7096cb7 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -630,7 +630,11 @@ private[hive] object TestHiveContext { val overrideConfs: Map[String, String] = Map( // Fewer shuffle partitions to speed up testing. - SQLConf.SHUFFLE_PARTITIONS.key -> "5" + SQLConf.SHUFFLE_PARTITIONS.key -> "5", + // TODO(SPARK-50244): We now isolate artifacts added by the `ADD JAR` command. This will break + // an existing Hive use case (one session adds JARs and another session uses them). We need + // to decide whether/how to enable isolation for Hive. 
+ SQLConf.ARTIFACTS_SESSION_ISOLATION_ENABLED.key -> "false" ) def makeWarehouseDir(): File = { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHiveSingleton.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHiveSingleton.scala index d50bf0b8fd603..770e1da94a1c7 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHiveSingleton.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHiveSingleton.scala @@ -40,4 +40,11 @@ trait TestHiveSingleton extends SparkFunSuite with BeforeAndAfterAll { } } + protected override def afterEach(): Unit = { + try { + spark.artifactManager.cleanUpResources() + } finally { + super.afterEach() + } + } } From a84ca5e46e638a3a4d274d06e8e0425cbe0f3f37 Mon Sep 17 00:00:00 2001 From: Kent Yao Date: Wed, 13 Nov 2024 10:26:17 +0800 Subject: [PATCH 14/39] [SPARK-50241][SQL] Replace NullIntolerant Mixin with Expression.nullIntolerant method ### What changes were proposed in this pull request? Replace NullIntolerant Mixin with Expression.nullIntolerant method ### Why are the changes needed? https://github.com/apache/spark/pull/48758#issuecomment-2458713378 via cloud-fan > This is not the first time that we are restricted by the trait-based tagging system. Extending a trait is static as it happens at compile but we really want to be more dynamic at runtime. For example, we added Expression#stateful and removed the trait StatefulExpression a while ago. I think the same thing happens to NullIntelerant now. Shall we also add Expression#nullIntolerant? ### Does this PR introduce _any_ user-facing change? no ### How was this patch tested? Modified ExpressionInfoSuite, - Check whether SQL expressions should extend NullIntolerant ### Was this patch authored or co-authored using generative AI tooling? No Closes #48772 from yaooqinn/SPARK-50241. 
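For illustration (not part of the patch): the commit replaces a compile-time marker trait with an overridable `Expression.nullIntolerant` method, so null intolerance becomes a per-class property that constraint construction can query directly. A self-contained toy sketch of the two tagging styles; `Expr`, `AddExpr` and `CoalesceExpr` are invented names for the sketch, not Spark classes.

    // Before: behaviour is tagged statically by mixing in a marker trait.
    trait NullIntolerantTag

    // After: behaviour is exposed as a method with a default, overridden where needed.
    abstract class Expr {
      def nullIntolerant: Boolean = false
    }
    class AddExpr extends Expr {
      override def nullIntolerant: Boolean = true // any null input yields a null output
    }
    class CoalesceExpr extends Expr // keeps the default: tolerates null inputs

    object NullIntolerantSketch {
      def main(args: Array[String]): Unit = {
        val exprs: Seq[Expr] = Seq(new AddExpr, new CoalesceExpr)
        // IsNotNull constraint inference can now call the method instead of matching on a trait.
        exprs.foreach(e => println(s"${e.getClass.getSimpleName}: nullIntolerant=${e.nullIntolerant}"))
      }
    }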
Authored-by: Kent Yao Signed-off-by: Kent Yao --- .../spark/sql/catalyst/expressions/Cast.scala | 2 +- .../sql/catalyst/expressions/Expression.scala | 10 +- .../sql/catalyst/expressions/TimeWindow.scala | 3 +- .../catalyst/expressions/ToJavaArray.scala | 2 +- .../sql/catalyst/expressions/TryEval.scala | 3 +- .../sql/catalyst/expressions/arithmetic.scala | 13 ++- .../expressions/bitwiseExpressions.scala | 10 +- .../expressions/collectionOperations.scala | 60 ++++++---- .../expressions/complexTypeCreator.scala | 8 +- .../expressions/complexTypeExtractors.scala | 3 +- .../catalyst/expressions/csvExpressions.scala | 8 +- .../expressions/datasketchesExpressions.scala | 8 +- .../expressions/datetimeExpressions.scala | 90 +++++++++------ .../expressions/decimalExpressions.scala | 7 +- .../spark/sql/catalyst/expressions/hash.scala | 12 +- .../expressions/intervalExpressions.scala | 32 ++++-- .../expressions/mathExpressions.scala | 27 +++-- .../expressions/namedExpressions.scala | 3 +- .../expressions/numberFormatExpressions.scala | 7 +- .../expressions/objects/objects.scala | 46 +------- .../sql/catalyst/expressions/package.scala | 7 -- .../sql/catalyst/expressions/predicates.scala | 27 ++--- .../expressions/regexpExpressions.scala | 17 +-- .../expressions/stringExpressions.scala | 107 +++++++++++------- .../variant/variantExpressions.scala | 5 +- .../sql/catalyst/expressions/xml/xpath.scala | 3 +- .../catalyst/expressions/xmlExpressions.scala | 6 +- .../sql/catalyst/optimizer/expressions.scala | 8 +- .../plans/logical/QueryPlanConstraints.scala | 2 +- .../sql/expressions/ExpressionInfoSuite.scala | 15 +-- .../apache/spark/sql/jdbc/JDBCV2Suite.scala | 6 +- 31 files changed, 302 insertions(+), 255 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 9a29cb4a2bfb3..0df5a70198932 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -481,9 +481,9 @@ case class Cast( extends UnaryExpression with TimeZoneAwareExpression with ToStringBase - with NullIntolerant with SupportQueryContext with QueryErrorsBase { + override def nullIntolerant: Boolean = true def this(child: Expression, dataType: DataType, timeZoneId: Option[String]) = this(child, dataType, timeZoneId, evalMode = EvalMode.fromSQLConf(SQLConf.get)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 140335ef8bdd6..f0f94f0881385 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -53,8 +53,6 @@ import org.apache.spark.sql.types._ * - [[Unevaluable]]: an expression that is not supposed to be evaluated. * - [[CodegenFallback]]: an expression that does not have code gen implemented and falls back to * interpreted mode. - * - [[NullIntolerant]]: an expression that is null intolerant (i.e. any null input will result in - * null output). * - [[NonSQLExpression]]: a common base trait for the expressions that do not have SQL * expressions like representation. For example, `ScalaUDF`, `ScalaUDAF`, * and object `MapObjects` and `Invoke`. 
@@ -141,6 +139,14 @@ abstract class Expression extends TreeNode[Expression] { */ def stateful: Boolean = false + + /** + * When an expression inherits this, meaning the expression is null intolerant (i.e. any null + * input will result in null output). We will use this information during constructing IsNotNull + * constraints. + */ + def nullIntolerant: Boolean = false + /** * Returns true if the expression could potentially throw an exception when evaluated. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala index 673f9397bb03f..65d9e238eb502 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala @@ -237,7 +237,8 @@ object TimeWindow { case class PreciseTimestampConversion( child: Expression, fromType: DataType, - toType: DataType) extends UnaryExpression with ExpectsInputTypes with NullIntolerant { + toType: DataType) extends UnaryExpression with ExpectsInputTypes { + override def nullIntolerant: Boolean = true override def inputTypes: Seq[AbstractDataType] = Seq(fromType) override def dataType: DataType = toType override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ToJavaArray.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ToJavaArray.scala index 861d7ff4024a3..8cf3cdef16c0e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ToJavaArray.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ToJavaArray.scala @@ -36,7 +36,6 @@ import org.apache.spark.util.Utils */ case class ToJavaArray(array: Expression) extends UnaryExpression - with NullIntolerant with RuntimeReplaceable with QueryErrorsBase { @@ -55,6 +54,7 @@ case class ToJavaArray(array: Expression) } override def foldable: Boolean = array.foldable + override def nullIntolerant: Boolean = true override def child: Expression = array override def prettyName: String = "to_java_array" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TryEval.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TryEval.scala index b7d0ffdb75fb0..7c84773006c26 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TryEval.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TryEval.scala @@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGe import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.types.{DataType, NumericType} -case class TryEval(child: Expression) extends UnaryExpression with NullIntolerant { +case class TryEval(child: Expression) extends UnaryExpression { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val childGen = child.genCode(ctx) ev.copy(code = code""" @@ -48,6 +48,7 @@ case class TryEval(child: Expression) extends UnaryExpression with NullIntoleran override def dataType: DataType = child.dataType override def nullable: Boolean = true + override def nullIntolerant: Boolean = true override protected def withNewChildInternal(newChild: Expression): Expression = copy(child = newChild) diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 497fdc0936267..015240472cf5a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -46,7 +46,8 @@ import org.apache.spark.unsafe.types.CalendarInterval case class UnaryMinus( child: Expression, failOnError: Boolean = SQLConf.get.ansiEnabled) - extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { + extends UnaryExpression with ImplicitCastInputTypes { + override def nullIntolerant: Boolean = true def this(child: Expression) = this(child, SQLConf.get.ansiEnabled) @@ -114,7 +115,8 @@ case class UnaryMinus( since = "1.5.0", group = "math_funcs") case class UnaryPositive(child: Expression) - extends RuntimeReplaceable with ImplicitCastInputTypes with NullIntolerant { + extends RuntimeReplaceable with ImplicitCastInputTypes { + override def nullIntolerant: Boolean = true override def prettyName: String = "positive" @@ -148,7 +150,8 @@ case class UnaryPositive(child: Expression) since = "1.2.0", group = "math_funcs") case class Abs(child: Expression, failOnError: Boolean = SQLConf.get.ansiEnabled) - extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { + extends UnaryExpression with ImplicitCastInputTypes { + override def nullIntolerant: Boolean = true def this(child: Expression) = this(child, SQLConf.get.ansiEnabled) @@ -185,8 +188,8 @@ case class Abs(child: Expression, failOnError: Boolean = SQLConf.get.ansiEnabled override protected def withNewChildInternal(newChild: Expression): Abs = copy(child = newChild) } -abstract class BinaryArithmetic extends BinaryOperator - with NullIntolerant with SupportQueryContext { +abstract class BinaryArithmetic extends BinaryOperator with SupportQueryContext { + override def nullIntolerant: Boolean = true protected val evalMode: EvalMode.Value diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala index 88085636a5ff1..26743ca6ff15e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala @@ -176,8 +176,8 @@ case class BitwiseXor(left: Expression, right: Expression) extends BinaryArithme since = "1.4.0", group = "bitwise_funcs") case class BitwiseNot(child: Expression) - extends UnaryExpression with ExpectsInputTypes with NullIntolerant { - + extends UnaryExpression with ExpectsInputTypes { + override def nullIntolerant: Boolean = true override def inputTypes: Seq[AbstractDataType] = Seq(IntegralType) override def dataType: DataType = child.dataType @@ -218,7 +218,8 @@ case class BitwiseNot(child: Expression) since = "3.0.0", group = "bitwise_funcs") case class BitwiseCount(child: Expression) - extends UnaryExpression with ExpectsInputTypes with NullIntolerant { + extends UnaryExpression with ExpectsInputTypes { + override def nullIntolerant: Boolean = true override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(IntegralType, BooleanType)) @@ -269,7 +270,8 @@ object BitwiseGetUtil { since = "3.2.0", group = "bitwise_funcs") case class BitwiseGet(left: Expression, right: Expression) - extends 
BinaryExpression with ImplicitCastInputTypes with NullIntolerant { + extends BinaryExpression with ImplicitCastInputTypes { + override def nullIntolerant: Boolean = true override def inputTypes: Seq[AbstractDataType] = Seq(IntegralType, IntegerType) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 10e64626d1a1b..fb130574d3474 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -188,7 +188,8 @@ case class ArraySize(child: Expression) group = "map_funcs", since = "2.0.0") case class MapKeys(child: Expression) - extends UnaryExpression with ExpectsInputTypes with NullIntolerant { + extends UnaryExpression with ExpectsInputTypes { + override def nullIntolerant: Boolean = true override def inputTypes: Seq[AbstractDataType] = Seq(MapType) @@ -477,7 +478,8 @@ object ArraysZip { group = "map_funcs", since = "2.0.0") case class MapValues(child: Expression) - extends UnaryExpression with ExpectsInputTypes with NullIntolerant { + extends UnaryExpression with ExpectsInputTypes { + override def nullIntolerant: Boolean = true override def inputTypes: Seq[AbstractDataType] = Seq(MapType) @@ -510,7 +512,8 @@ case class MapValues(child: Expression) group = "map_funcs", since = "3.0.0") case class MapEntries(child: Expression) - extends UnaryExpression with ExpectsInputTypes with NullIntolerant { + extends UnaryExpression with ExpectsInputTypes { + override def nullIntolerant: Boolean = true override def inputTypes: Seq[AbstractDataType] = Seq(MapType) @@ -813,8 +816,8 @@ case class MapConcat(children: Seq[Expression]) since = "2.4.0") case class MapFromEntries(child: Expression) extends UnaryExpression - with NullIntolerant with QueryErrorsBase { + override def nullIntolerant: Boolean = true @transient private lazy val dataTypeDetails: Option[(MapType, Boolean, Boolean)] = child.dataType match { @@ -894,7 +897,8 @@ case class MapFromEntries(child: Expression) } case class MapSort(base: Expression) - extends UnaryExpression with NullIntolerant with QueryErrorsBase { + extends UnaryExpression with QueryErrorsBase { + override def nullIntolerant: Boolean = true val keyType: DataType = base.dataType.asInstanceOf[MapType].keyType val valueType: DataType = base.dataType.asInstanceOf[MapType].valueType @@ -1048,7 +1052,8 @@ case class MapSort(base: Expression) since = "1.5.0") // scalastyle:on line.size.limit case class SortArray(base: Expression, ascendingOrder: Expression) - extends BinaryExpression with ExpectsInputTypes with NullIntolerant with QueryErrorsBase { + extends BinaryExpression with ExpectsInputTypes with QueryErrorsBase { + override def nullIntolerant: Boolean = true def this(e: Expression) = this(e, Literal(true)) @@ -1345,8 +1350,8 @@ case class Shuffle(child: Expression, randomSeed: Option[Long] = None) extends U """ ) case class Reverse(child: Expression) - extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { - + extends UnaryExpression with ImplicitCastInputTypes { + override def nullIntolerant: Boolean = true // Input types are utilized by type coercion in ImplicitTypeCasts. 
override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(StringTypeWithCollation, ArrayType)) @@ -1421,8 +1426,9 @@ case class Reverse(child: Expression) group = "array_funcs", since = "1.5.0") case class ArrayContains(left: Expression, right: Expression) - extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant with Predicate + extends BinaryExpression with ImplicitCastInputTypes with Predicate with QueryErrorsBase { + override def nullIntolerant: Boolean = true @transient private lazy val ordering: Ordering[Any] = TypeUtils.getInterpretedOrdering(right.dataType) @@ -1551,9 +1557,9 @@ case class ArrayContains(left: Expression, right: Expression) case class ArrayBinarySearch(array: Expression, value: Expression) extends BinaryExpression with ImplicitCastInputTypes - with NullIntolerant with RuntimeReplaceable with QueryErrorsBase { + override def nullIntolerant: Boolean = true override def left: Expression = array override def right: Expression = value @@ -1764,7 +1770,8 @@ case class ArrayAppend(left: Expression, right: Expression) extends ArrayPendBas since = "2.4.0") // scalastyle:off line.size.limit case class ArraysOverlap(left: Expression, right: Expression) - extends BinaryArrayExpressionWithImplicitCast with NullIntolerant with Predicate { + extends BinaryArrayExpressionWithImplicitCast with Predicate { + override def nullIntolerant: Boolean = true override def checkInputDataTypes(): TypeCheckResult = super.checkInputDataTypes() match { case TypeCheckResult.TypeCheckSuccess => @@ -1990,7 +1997,8 @@ case class ArraysOverlap(left: Expression, right: Expression) since = "2.4.0") // scalastyle:on line.size.limit case class Slice(x: Expression, start: Expression, length: Expression) - extends TernaryExpression with ImplicitCastInputTypes with NullIntolerant { + extends TernaryExpression with ImplicitCastInputTypes { + override def nullIntolerant: Boolean = true override def dataType: DataType = x.dataType @@ -2289,8 +2297,8 @@ case class ArrayJoin( group = "array_funcs", since = "2.4.0") case class ArrayMin(child: Expression) - extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { - + extends UnaryExpression with ImplicitCastInputTypes { + override def nullIntolerant: Boolean = true override def nullable: Boolean = true override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType) @@ -2362,8 +2370,8 @@ case class ArrayMin(child: Expression) group = "array_funcs", since = "2.4.0") case class ArrayMax(child: Expression) - extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { - + extends UnaryExpression with ImplicitCastInputTypes { + override def nullIntolerant: Boolean = true override def nullable: Boolean = true override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType) @@ -2444,7 +2452,8 @@ case class ArrayMax(child: Expression) group = "array_funcs", since = "2.4.0") case class ArrayPosition(left: Expression, right: Expression) - extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant with QueryErrorsBase { + extends BinaryExpression with ImplicitCastInputTypes with QueryErrorsBase { + override def nullIntolerant: Boolean = true @transient private lazy val ordering: Ordering[Any] = TypeUtils.getInterpretedOrdering(right.dataType) @@ -2592,8 +2601,9 @@ case class ElementAt( // The value to return if index is out of bound defaultValueOutOfBound: Option[Literal] = None, failOnError: Boolean = SQLConf.get.ansiEnabled) - extends GetMapValueUtil with GetArrayItemUtil with NullIntolerant with 
SupportQueryContext + extends GetMapValueUtil with GetArrayItemUtil with SupportQueryContext with QueryErrorsBase { + override def nullIntolerant: Boolean = true def this(left: Expression, right: Expression) = this(left, right, None, SQLConf.get.ansiEnabled) @@ -3055,8 +3065,9 @@ case class Concat(children: Seq[Expression]) extends ComplexTypeMergingExpressio """, group = "array_funcs", since = "2.4.0") -case class Flatten(child: Expression) extends UnaryExpression with NullIntolerant +case class Flatten(child: Expression) extends UnaryExpression with QueryErrorsBase { + override def nullIntolerant: Boolean = true private def childDataType: ArrayType = child.dataType.asInstanceOf[ArrayType] @@ -3941,8 +3952,8 @@ case class ArrayRepeat(left: Expression, right: Expression) group = "array_funcs", since = "2.4.0") case class ArrayRemove(left: Expression, right: Expression) - extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant with QueryErrorsBase { - + extends BinaryExpression with ImplicitCastInputTypes with QueryErrorsBase { + override def nullIntolerant: Boolean = true override def dataType: DataType = left.dataType override def inputTypes: Seq[AbstractDataType] = { @@ -4153,8 +4164,8 @@ trait ArraySetLike { group = "array_funcs", since = "2.4.0") case class ArrayDistinct(child: Expression) - extends UnaryExpression with ArraySetLike with ExpectsInputTypes with NullIntolerant { - + extends UnaryExpression with ArraySetLike with ExpectsInputTypes { + override def nullIntolerant: Boolean = true override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType) override def dataType: DataType = child.dataType @@ -4317,7 +4328,8 @@ case class ArrayDistinct(child: Expression) * Will become common base class for [[ArrayUnion]], [[ArrayIntersect]], and [[ArrayExcept]]. 
*/ trait ArrayBinaryLike - extends BinaryArrayExpressionWithImplicitCast with ArraySetLike with NullIntolerant { + extends BinaryArrayExpressionWithImplicitCast with ArraySetLike { + override def nullIntolerant: Boolean = true override protected def dt: DataType = dataType override protected def et: DataType = elementType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index b8b47f2763f5b..2098ee274dfe0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -301,8 +301,8 @@ object CreateMap { since = "2.4.0", group = "map_funcs") case class MapFromArrays(left: Expression, right: Expression) - extends BinaryExpression with ExpectsInputTypes with NullIntolerant { - + extends BinaryExpression with ExpectsInputTypes { + override def nullIntolerant: Boolean = true override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, ArrayType) override def checkInputDataTypes(): TypeCheckResult = { @@ -562,8 +562,8 @@ case class CreateNamedStruct(children: Seq[Expression]) extends Expression with group = "map_funcs") // scalastyle:on line.size.limit case class StringToMap(text: Expression, pairDelim: Expression, keyValueDelim: Expression) - extends TernaryExpression with ExpectsInputTypes with NullIntolerant { - + extends TernaryExpression with ExpectsInputTypes { + override def nullIntolerant: Boolean = true def this(child: Expression, pairDelim: Expression) = { this(child, pairDelim, Literal(":")) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala index ff94322efdaa4..3b8d4e09905e5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala @@ -90,7 +90,8 @@ object ExtractValue { } } -trait ExtractValue extends Expression with NullIntolerant { +trait ExtractValue extends Expression { + override def nullIntolerant: Boolean = true final override val nodePatterns: Seq[TreePattern] = Seq(EXTRACT_VALUE) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala index 5393b2bde93b0..e9cdc184e55a3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala @@ -60,9 +60,8 @@ case class CsvToStructs( extends UnaryExpression with TimeZoneAwareExpression with CodegenFallback - with ExpectsInputTypes - with NullIntolerant { - + with ExpectsInputTypes { + override def nullIntolerant: Boolean = true override def nullable: Boolean = child.nullable // The CSV input data might be missing certain fields. 
We force the nullability @@ -238,7 +237,8 @@ case class StructsToCsv( options: Map[String, String], child: Expression, timeZoneId: Option[String] = None) - extends UnaryExpression with TimeZoneAwareExpression with ExpectsInputTypes with NullIntolerant { + extends UnaryExpression with TimeZoneAwareExpression with ExpectsInputTypes { + override def nullIntolerant: Boolean = true override def nullable: Boolean = true def this(options: Map[String, String], child: Expression) = this(options, child, None) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datasketchesExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datasketchesExpressions.scala index fa917dfc5c83f..a4ac0bdbb11d3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datasketchesExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datasketchesExpressions.scala @@ -39,8 +39,8 @@ import org.apache.spark.sql.types.{AbstractDataType, BinaryType, BooleanType, Da case class HllSketchEstimate(child: Expression) extends UnaryExpression with CodegenFallback - with ExpectsInputTypes - with NullIntolerant { + with ExpectsInputTypes { + override def nullIntolerant: Boolean = true override protected def withNewChildInternal(newChild: Expression): HllSketchEstimate = copy(child = newChild) @@ -80,8 +80,8 @@ case class HllSketchEstimate(child: Expression) case class HllUnion(first: Expression, second: Expression, third: Expression) extends TernaryExpression with CodegenFallback - with ExpectsInputTypes - with NullIntolerant { + with ExpectsInputTypes { + override def nullIntolerant: Boolean = true // The default target type (register size) to use. private val targetType = TgtHllType.HLL_8 diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index 9467f1146c431..f2ba3ed95b850 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -309,8 +309,8 @@ case class CurrentBatchTimestamp( group = "datetime_funcs", since = "1.5.0") case class DateAdd(startDate: Expression, days: Expression) - extends BinaryExpression with ExpectsInputTypes with NullIntolerant { - + extends BinaryExpression with ExpectsInputTypes { + override def nullIntolerant: Boolean = true override def left: Expression = startDate override def right: Expression = days @@ -348,7 +348,8 @@ case class DateAdd(startDate: Expression, days: Expression) group = "datetime_funcs", since = "1.5.0") case class DateSub(startDate: Expression, days: Expression) - extends BinaryExpression with ExpectsInputTypes with NullIntolerant { + extends BinaryExpression with ExpectsInputTypes { + override def nullIntolerant: Boolean = true override def left: Expression = startDate override def right: Expression = days @@ -374,8 +375,8 @@ case class DateSub(startDate: Expression, days: Expression) } trait GetTimeField extends UnaryExpression - with TimeZoneAwareExpression with ImplicitCastInputTypes with NullIntolerant { - + with TimeZoneAwareExpression with ImplicitCastInputTypes { + override def nullIntolerant: Boolean = true val func: (Long, ZoneId) => Any val funcName: String @@ -461,7 +462,8 @@ case class SecondWithFraction(child: Expression, timeZoneId: 
Option[String] = No copy(child = newChild) } -trait GetDateField extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { +trait GetDateField extends UnaryExpression with ImplicitCastInputTypes { + override def nullIntolerant: Boolean = true val func: Int => Any val funcName: String @@ -505,7 +507,9 @@ case class DayOfYear(child: Expression) extends GetDateField { group = "datetime_funcs", since = "3.1.0") case class DateFromUnixDate(child: Expression) extends UnaryExpression - with ImplicitCastInputTypes with NullIntolerant { + with ImplicitCastInputTypes { + override def nullIntolerant: Boolean = true + override def inputTypes: Seq[AbstractDataType] = Seq(IntegerType) override def dataType: DataType = DateType @@ -531,7 +535,9 @@ case class DateFromUnixDate(child: Expression) extends UnaryExpression group = "datetime_funcs", since = "3.1.0") case class UnixDate(child: Expression) extends UnaryExpression - with ExpectsInputTypes with NullIntolerant { + with ExpectsInputTypes { + override def nullIntolerant: Boolean = true + override def inputTypes: Seq[AbstractDataType] = Seq(DateType) override def dataType: DataType = IntegerType @@ -548,7 +554,8 @@ case class UnixDate(child: Expression) extends UnaryExpression } abstract class IntegralToTimestampBase extends UnaryExpression - with ExpectsInputTypes with NullIntolerant { + with ExpectsInputTypes { + override def nullIntolerant: Boolean = true protected def upScaleFactor: Long @@ -583,7 +590,8 @@ abstract class IntegralToTimestampBase extends UnaryExpression since = "3.1.0") // scalastyle:on line.size.limit case class SecondsToTimestamp(child: Expression) extends UnaryExpression - with ExpectsInputTypes with NullIntolerant { + with ExpectsInputTypes { + override def nullIntolerant: Boolean = true override def inputTypes: Seq[AbstractDataType] = Seq(NumericType) @@ -682,7 +690,8 @@ case class MicrosToTimestamp(child: Expression) } abstract class TimestampToLongBase extends UnaryExpression - with ExpectsInputTypes with NullIntolerant { + with ExpectsInputTypes { + override def nullIntolerant: Boolean = true protected def scaleFactor: Long @@ -954,8 +963,8 @@ case class DayName(child: Expression) extends GetDateField { since = "1.5.0") // scalastyle:on line.size.limit case class DateFormatClass(left: Expression, right: Expression, timeZoneId: Option[String] = None) - extends BinaryExpression with TimestampFormatterHelper with ImplicitCastInputTypes - with NullIntolerant { + extends BinaryExpression with TimestampFormatterHelper with ImplicitCastInputTypes { + override def nullIntolerant: Boolean = true def this(left: Expression, right: Expression) = this(left, right, None) @@ -1427,8 +1436,8 @@ abstract class UnixTime extends ToTimestamp { since = "1.5.0") // scalastyle:on line.size.limit case class FromUnixTime(sec: Expression, format: Expression, timeZoneId: Option[String] = None) - extends BinaryExpression with TimestampFormatterHelper with ImplicitCastInputTypes - with NullIntolerant { + extends BinaryExpression with TimestampFormatterHelper with ImplicitCastInputTypes { + override def nullIntolerant: Boolean = true def this(sec: Expression, format: Expression) = this(sec, format, None) @@ -1497,7 +1506,9 @@ case class FromUnixTime(sec: Expression, format: Expression, timeZoneId: Option[ group = "datetime_funcs", since = "1.5.0") case class LastDay(startDate: Expression) - extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { + extends UnaryExpression with ImplicitCastInputTypes { + override def 
nullIntolerant: Boolean = true + override def child: Expression = startDate override def inputTypes: Seq[AbstractDataType] = Seq(DateType) @@ -1546,7 +1557,8 @@ case class NextDay( startDate: Expression, dayOfWeek: Expression, failOnError: Boolean = SQLConf.get.ansiEnabled) - extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant { + extends BinaryExpression with ImplicitCastInputTypes { + override def nullIntolerant: Boolean = true override def left: Expression = startDate override def right: Expression = dayOfWeek @@ -1628,7 +1640,8 @@ case class NextDay( * Adds an interval to timestamp. */ case class TimeAdd(start: Expression, interval: Expression, timeZoneId: Option[String] = None) - extends BinaryExpression with TimeZoneAwareExpression with ExpectsInputTypes with NullIntolerant { + extends BinaryExpression with TimeZoneAwareExpression with ExpectsInputTypes { + override def nullIntolerant: Boolean = true def this(start: Expression, interval: Expression) = this(start, interval, None) @@ -1709,7 +1722,8 @@ case class DateAddInterval( interval: Expression, timeZoneId: Option[String] = None, ansiEnabled: Boolean = SQLConf.get.ansiEnabled) - extends BinaryExpression with ExpectsInputTypes with TimeZoneAwareExpression with NullIntolerant { + extends BinaryExpression with ExpectsInputTypes with TimeZoneAwareExpression { + override def nullIntolerant: Boolean = true override def left: Expression = start override def right: Expression = interval @@ -1761,7 +1775,8 @@ case class DateAddInterval( copy(start = newLeft, interval = newRight) } -sealed trait UTCTimestamp extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant { +sealed trait UTCTimestamp extends BinaryExpression with ImplicitCastInputTypes { + override def nullIntolerant: Boolean = true val func: (Long, String) => Long val funcName: String @@ -1879,8 +1894,8 @@ case class ToUTCTimestamp(left: Expression, right: Expression) extends UTCTimest copy(left = newLeft, right = newRight) } -abstract class AddMonthsBase extends BinaryExpression with ImplicitCastInputTypes - with NullIntolerant { +abstract class AddMonthsBase extends BinaryExpression with ImplicitCastInputTypes { + override def nullIntolerant: Boolean = true override def dataType: DataType = DateType override def nullSafeEval(start: Any, months: Any): Any = { @@ -1942,7 +1957,8 @@ case class TimestampAddYMInterval( timestamp: Expression, interval: Expression, timeZoneId: Option[String] = None) - extends BinaryExpression with TimeZoneAwareExpression with ExpectsInputTypes with NullIntolerant { + extends BinaryExpression with TimeZoneAwareExpression with ExpectsInputTypes { + override def nullIntolerant: Boolean = true def this(timestamp: Expression, interval: Expression) = this(timestamp, interval, None) @@ -2008,8 +2024,8 @@ case class MonthsBetween( date2: Expression, roundOff: Expression, timeZoneId: Option[String] = None) - extends TernaryExpression with TimeZoneAwareExpression with ImplicitCastInputTypes - with NullIntolerant { + extends TernaryExpression with TimeZoneAwareExpression with ImplicitCastInputTypes { + override def nullIntolerant: Boolean = true def this(date1: Expression, date2: Expression) = this(date1, date2, Literal.TrueLiteral, None) @@ -2427,7 +2443,8 @@ case class TruncTimestamp( group = "datetime_funcs", since = "1.5.0") case class DateDiff(endDate: Expression, startDate: Expression) - extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant { + extends BinaryExpression with ImplicitCastInputTypes { + 
override def nullIntolerant: Boolean = true override def left: Expression = endDate override def right: Expression = startDate @@ -2471,7 +2488,8 @@ case class MakeDate( month: Expression, day: Expression, failOnError: Boolean = SQLConf.get.ansiEnabled) - extends TernaryExpression with ImplicitCastInputTypes with NullIntolerant { + extends TernaryExpression with ImplicitCastInputTypes { + override def nullIntolerant: Boolean = true def this(year: Expression, month: Expression, day: Expression) = this(year, month, day, SQLConf.get.ansiEnabled) @@ -2750,8 +2768,8 @@ case class MakeTimestamp( timeZoneId: Option[String] = None, failOnError: Boolean = SQLConf.get.ansiEnabled, override val dataType: DataType = SQLConf.get.timestampType) - extends SeptenaryExpression with TimeZoneAwareExpression with ImplicitCastInputTypes - with NullIntolerant { + extends SeptenaryExpression with TimeZoneAwareExpression with ImplicitCastInputTypes { + override def nullIntolerant: Boolean = true def this( year: Expression, @@ -3171,8 +3189,8 @@ case class SubtractTimestamps( timeZoneId: Option[String] = None) extends BinaryExpression with TimeZoneAwareExpression - with ExpectsInputTypes - with NullIntolerant { + with ExpectsInputTypes { + override def nullIntolerant: Boolean = true def this(endTimestamp: Expression, startTimestamp: Expression) = this(endTimestamp, startTimestamp, SQLConf.get.legacyIntervalEnabled) @@ -3234,7 +3252,8 @@ case class SubtractDates( left: Expression, right: Expression, legacyInterval: Boolean) - extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant { + extends BinaryExpression with ImplicitCastInputTypes { + override def nullIntolerant: Boolean = true def this(left: Expression, right: Expression) = this(left, right, SQLConf.get.legacyIntervalEnabled) @@ -3303,7 +3322,8 @@ case class ConvertTimezone( sourceTz: Expression, targetTz: Expression, sourceTs: Expression) - extends TernaryExpression with ImplicitCastInputTypes with NullIntolerant { + extends TernaryExpression with ImplicitCastInputTypes { + override def nullIntolerant: Boolean = true def this(targetTz: Expression, sourceTs: Expression) = this(CurrentTimeZone(), targetTz, sourceTs) @@ -3380,8 +3400,8 @@ case class TimestampAdd( timeZoneId: Option[String] = None) extends BinaryExpression with ImplicitCastInputTypes - with NullIntolerant with TimeZoneAwareExpression { + override def nullIntolerant: Boolean = true def this(unit: String, quantity: Expression, timestamp: Expression) = this(unit, quantity, timestamp, None) @@ -3469,8 +3489,8 @@ case class TimestampDiff( timeZoneId: Option[String] = None) extends BinaryExpression with ImplicitCastInputTypes - with NullIntolerant with TimeZoneAwareExpression { + override def nullIntolerant: Boolean = true def this(unit: String, startTimestamp: Expression, endTimestamp: Expression) = this(unit, startTimestamp, endTimestamp, None) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala index f7509f124ab50..46ab43074409a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala @@ -32,8 +32,8 @@ import org.apache.spark.sql.types._ * Note: this expression is internal and created only by the optimizer, * we don't need to do type check for it. 
*/ -case class UnscaledValue(child: Expression) extends UnaryExpression with NullIntolerant { - +case class UnscaledValue(child: Expression) extends UnaryExpression { + override def nullIntolerant: Boolean = true override def dataType: DataType = LongType override def toString: String = s"UnscaledValue($child)" @@ -57,7 +57,8 @@ case class MakeDecimal( child: Expression, precision: Int, scale: Int, - nullOnOverflow: Boolean) extends UnaryExpression with NullIntolerant { + nullOnOverflow: Boolean) extends UnaryExpression { + override def nullIntolerant: Boolean = true def this(child: Expression, precision: Int, scale: Int) = { this(child, precision, scale, !SQLConf.get.ansiEnabled) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala index 7128190902550..79879dc0edb4c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala @@ -61,7 +61,8 @@ import org.apache.spark.util.ArrayImplicits._ since = "1.5.0", group = "hash_funcs") case class Md5(child: Expression) - extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { + extends UnaryExpression with ImplicitCastInputTypes { + override def nullIntolerant: Boolean = true override def dataType: DataType = SQLConf.get.defaultStringType @@ -101,7 +102,8 @@ case class Md5(child: Expression) group = "hash_funcs") // scalastyle:on line.size.limit case class Sha2(left: Expression, right: Expression) - extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant with Serializable { + extends BinaryExpression with ImplicitCastInputTypes with Serializable { + override def nullIntolerant: Boolean = true override def dataType: DataType = SQLConf.get.defaultStringType override def nullable: Boolean = true @@ -167,7 +169,8 @@ case class Sha2(left: Expression, right: Expression) since = "1.5.0", group = "hash_funcs") case class Sha1(child: Expression) - extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { + extends UnaryExpression with ImplicitCastInputTypes { + override def nullIntolerant: Boolean = true override def dataType: DataType = SQLConf.get.defaultStringType @@ -199,7 +202,8 @@ case class Sha1(child: Expression) since = "1.5.0", group = "hash_funcs") case class Crc32(child: Expression) - extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { + extends UnaryExpression with ImplicitCastInputTypes { + override def nullIntolerant: Boolean = true override def dataType: DataType = LongType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala index 5363f1bba390a..1ce7dfd39acc6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala @@ -37,7 +37,9 @@ import org.apache.spark.unsafe.types.CalendarInterval abstract class ExtractIntervalPart[T]( val dataType: DataType, func: T => Any, - funcName: String) extends UnaryExpression with NullIntolerant with Serializable { + funcName: String) extends UnaryExpression with Serializable { + override def nullIntolerant: Boolean = true + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): 
ExprCode = { val iu = IntervalUtils.getClass.getName.stripSuffix("$") defineCodeGen(ctx, ev, c => s"$iu.$funcName($c)") @@ -168,7 +170,9 @@ object ExtractIntervalPart { abstract class IntervalNumOperation( interval: Expression, num: Expression) - extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant with Serializable { + extends BinaryExpression with ImplicitCastInputTypes with Serializable { + override def nullIntolerant: Boolean = true + override def left: Expression = interval override def right: Expression = num @@ -341,7 +345,8 @@ case class MakeInterval( mins: Expression, secs: Expression, failOnError: Boolean = SQLConf.get.ansiEnabled) - extends SeptenaryExpression with ImplicitCastInputTypes with NullIntolerant { + extends SeptenaryExpression with ImplicitCastInputTypes { + override def nullIntolerant: Boolean = true def this( years: Expression, @@ -476,7 +481,8 @@ case class MakeDTInterval( hours: Expression, mins: Expression, secs: Expression) - extends QuaternaryExpression with ImplicitCastInputTypes with NullIntolerant { + extends QuaternaryExpression with ImplicitCastInputTypes { + override def nullIntolerant: Boolean = true def this( days: Expression, @@ -550,7 +556,8 @@ case class MakeDTInterval( group = "datetime_funcs") // scalastyle:on line.size.limit case class MakeYMInterval(years: Expression, months: Expression) - extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant with Serializable { + extends BinaryExpression with ImplicitCastInputTypes with Serializable { + override def nullIntolerant: Boolean = true def this(years: Expression) = this(years, Literal(0)) def this() = this(Literal(0)) @@ -586,7 +593,8 @@ case class MakeYMInterval(years: Expression, months: Expression) case class MultiplyYMInterval( interval: Expression, num: Expression) - extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant with Serializable { + extends BinaryExpression with ImplicitCastInputTypes with Serializable { + override def nullIntolerant: Boolean = true override def left: Expression = interval override def right: Expression = num @@ -638,7 +646,8 @@ case class MultiplyYMInterval( case class MultiplyDTInterval( interval: Expression, num: Expression) - extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant with Serializable { + extends BinaryExpression with ImplicitCastInputTypes with Serializable { + override def nullIntolerant: Boolean = true override def left: Expression = interval override def right: Expression = num @@ -724,8 +733,8 @@ trait IntervalDivide { case class DivideYMInterval( interval: Expression, num: Expression) - extends BinaryExpression with ImplicitCastInputTypes with IntervalDivide - with NullIntolerant with Serializable { + extends BinaryExpression with ImplicitCastInputTypes with IntervalDivide with Serializable { + override def nullIntolerant: Boolean = true override def left: Expression = interval override def right: Expression = num @@ -806,8 +815,9 @@ case class DivideYMInterval( case class DivideDTInterval( interval: Expression, num: Expression) - extends BinaryExpression with ImplicitCastInputTypes with IntervalDivide - with NullIntolerant with Serializable { + extends BinaryExpression with ImplicitCastInputTypes with IntervalDivide with Serializable { + override def nullIntolerant: Boolean = true + override def left: Expression = interval override def right: Expression = num diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala index 71fd43a8d9423..30f07dcc1e67e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala @@ -63,7 +63,8 @@ abstract class LeafMathExpression(c: Double, name: String) * @param name The short name of the function */ abstract class UnaryMathExpression(val f: Double => Double, name: String) - extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant with Serializable { + extends UnaryExpression with ImplicitCastInputTypes with Serializable { + override def nullIntolerant: Boolean = true override def inputTypes: Seq[AbstractDataType] = Seq(DoubleType) override def dataType: DataType = DoubleType @@ -117,7 +118,8 @@ abstract class UnaryLogExpression(f: Double => Double, name: String) * @param name The short name of the function */ abstract class BinaryMathExpression(f: (Double, Double) => Double, name: String) - extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant with Serializable { + extends BinaryExpression with ImplicitCastInputTypes with Serializable { + override def nullIntolerant: Boolean = true override def inputTypes: Seq[DataType] = Seq(DoubleType, DoubleType) @@ -443,8 +445,8 @@ case class Conv( ansiEnabled: Boolean = SQLConf.get.ansiEnabled) extends TernaryExpression with ImplicitCastInputTypes - with NullIntolerant with SupportQueryContext { + override def nullIntolerant: Boolean = true def this(numExpr: Expression, fromBaseExpr: Expression, toBaseExpr: Expression) = this(numExpr, fromBaseExpr, toBaseExpr, ansiEnabled = SQLConf.get.ansiEnabled) @@ -629,7 +631,8 @@ object Factorial { since = "1.5.0", group = "math_funcs") case class Factorial(child: Expression) - extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { + extends UnaryExpression with ImplicitCastInputTypes { + override def nullIntolerant: Boolean = true override def inputTypes: Seq[DataType] = Seq(IntegerType) @@ -1002,8 +1005,8 @@ case class ToRadians(child: Expression) extends UnaryMathExpression(math.toRadia group = "math_funcs") // scalastyle:on line.size.limit case class Bin(child: Expression) - extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant with Serializable { - + extends UnaryExpression with ImplicitCastInputTypes with Serializable { + override def nullIntolerant: Boolean = true override def inputTypes: Seq[DataType] = Seq(LongType) override def dataType: DataType = SQLConf.get.defaultStringType @@ -1111,7 +1114,8 @@ object Hex { since = "1.5.0", group = "math_funcs") case class Hex(child: Expression) - extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { + extends UnaryExpression with ImplicitCastInputTypes { + override def nullIntolerant: Boolean = true override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(LongType, BinaryType, StringTypeWithCollation)) @@ -1154,7 +1158,8 @@ case class Hex(child: Expression) since = "1.5.0", group = "math_funcs") case class Unhex(child: Expression, failOnError: Boolean = false) - extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { + extends UnaryExpression with ImplicitCastInputTypes { + override def nullIntolerant: Boolean = true def this(expr: Expression) = this(expr, false) @@ -1251,7 +1256,8 @@ case class Pow(left: Expression, right: Expression) } sealed trait BitShiftOperation - extends 
BinaryExpression with ImplicitCastInputTypes with NullIntolerant { + extends BinaryExpression with ImplicitCastInputTypes { + override def nullIntolerant: Boolean = true def symbol: String def shiftInt: (Int, Int) => Int @@ -1832,7 +1838,8 @@ case class WidthBucket( minValue: Expression, maxValue: Expression, numBucket: Expression) - extends QuaternaryExpression with ImplicitCastInputTypes with NullIntolerant { + extends QuaternaryExpression with ImplicitCastInputTypes { + override def nullIntolerant: Boolean = true override def inputTypes: Seq[AbstractDataType] = Seq( TypeCollection(DoubleType, YearMonthIntervalType, DayTimeIntervalType), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index 3258a57bb1236..f5f35050401ba 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -104,7 +104,8 @@ trait NamedExpression extends Expression { def newInstance(): NamedExpression } -abstract class Attribute extends LeafExpression with NamedExpression with NullIntolerant { +abstract class Attribute extends LeafExpression with NamedExpression { + override def nullIntolerant: Boolean = true @transient override lazy val references: AttributeSet = AttributeSet(this) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/numberFormatExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/numberFormatExpressions.scala index 0d137a9b8f6e5..d4dcfdc5e72fb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/numberFormatExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/numberFormatExpressions.scala @@ -32,7 +32,8 @@ import org.apache.spark.sql.types.{AbstractDataType, BinaryType, DataType, Datet import org.apache.spark.unsafe.types.UTF8String abstract class ToNumberBase(left: Expression, right: Expression, errorOnFail: Boolean) - extends BinaryExpression with Serializable with ImplicitCastInputTypes with NullIntolerant { + extends BinaryExpression with Serializable with ImplicitCastInputTypes { + override def nullIntolerant: Boolean = true private lazy val numberFormatter = { val value = right.eval() @@ -273,7 +274,9 @@ object ToCharacterBuilder extends ExpressionBuilder { } case class ToCharacter(left: Expression, right: Expression) - extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant { + extends BinaryExpression with ImplicitCastInputTypes { + override def nullIntolerant: Boolean = true + private lazy val numberFormatter = { val value = right.eval() if (value != null) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index f49fd697492a2..1d4b66557e478 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -332,6 +332,7 @@ case class StaticInvoke( } override def nullable: Boolean = needNullCheck || returnNullable + override def nullIntolerant: Boolean = propagateNull override def children: Seq[Expression] = arguments override lazy val deterministic: Boolean = 
isDeterministic && arguments.forall(_.deterministic) @@ -412,28 +413,6 @@ case class StaticInvoke( }.$functionName(${arguments.mkString(", ")}))" } -object StaticInvoke { - def withNullIntolerant( - staticObject: Class[_], - dataType: DataType, - functionName: String, - arguments: Seq[Expression] = Nil, - inputTypes: Seq[AbstractDataType] = Nil, - propagateNull: Boolean = true, - returnNullable: Boolean = true, - isDeterministic: Boolean = true, - scalarFunction: Option[ScalarFunction[_]] = None): StaticInvoke = - new StaticInvoke( - staticObject, - dataType, - functionName, - arguments, - inputTypes, - propagateNull, - returnNullable, - isDeterministic, scalarFunction) with NullIntolerant -} - /** * Calls the specified function on an object, optionally passing arguments. If the `targetObject` * expression evaluates to null then null will be returned. @@ -470,7 +449,7 @@ case class Invoke( propagateNull: Boolean = true, returnNullable : Boolean = true, isDeterministic: Boolean = true) extends InvokeLike { - + override def nullIntolerant: Boolean = propagateNull lazy val argClasses = EncoderUtils.expressionJavaClasses(arguments) final override val nodePatterns: Seq[TreePattern] = Seq(INVOKE) @@ -577,27 +556,6 @@ case class Invoke( copy(targetObject = newChildren.head, arguments = newChildren.tail) } -object Invoke { - def withNullIntolerant( - targetObject: Expression, - functionName: String, - dataType: DataType, - arguments: Seq[Expression] = Nil, - methodInputTypes: Seq[AbstractDataType] = Nil, - propagateNull: Boolean = true, - returnNullable: Boolean = true, - isDeterministic: Boolean = true): Invoke = - new Invoke( - targetObject, - functionName, - dataType, - arguments, - methodInputTypes, - propagateNull, - returnNullable, - isDeterministic) with NullIntolerant -} - object NewInstance { def apply( cls: Class[_], diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala index 20105b87004f4..86d3cee6a0600 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala @@ -405,11 +405,4 @@ package object expressions { } } } - - /** - * When an expression inherits this, meaning the expression is null intolerant (i.e. any null - * input will result in null output). We will use this information during constructing IsNotNull - * constraints. - */ - trait NullIntolerant extends Expression } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 5897be12db193..986bc63363d5d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -269,10 +269,8 @@ trait PredicateHelper extends AliasHelper with Logging { } // If one expression and its children are null intolerant, it is null intolerant. 
- protected def isNullIntolerant(expr: Expression): Boolean = expr match { - case e: NullIntolerant => e.children.forall(isNullIntolerant) - case _ => false - } + protected def isNullIntolerant(expr: Expression): Boolean = + expr.nullIntolerant && expr.children.forall(isNullIntolerant) protected def outputWithNullability( output: Seq[Attribute], @@ -317,7 +315,8 @@ trait PredicateHelper extends AliasHelper with Logging { since = "1.0.0", group = "predicate_funcs") case class Not(child: Expression) - extends UnaryExpression with Predicate with ImplicitCastInputTypes with NullIntolerant { + extends UnaryExpression with Predicate with ImplicitCastInputTypes { + override def nullIntolerant: Boolean = true override def toString: String = s"NOT $child" @@ -1070,8 +1069,8 @@ object Equality { since = "1.0.0", group = "predicate_funcs") case class EqualTo(left: Expression, right: Expression) - extends BinaryComparison with NullIntolerant { - + extends BinaryComparison { + override def nullIntolerant: Boolean = true override def symbol: String = "=" // +---------+---------+---------+---------+ @@ -1219,8 +1218,8 @@ case class EqualNull(left: Expression, right: Expression, replacement: Expressio since = "1.0.0", group = "predicate_funcs") case class LessThan(left: Expression, right: Expression) - extends BinaryComparison with NullIntolerant { - + extends BinaryComparison { + override def nullIntolerant: Boolean = true override def symbol: String = "<" protected override def nullSafeEval(input1: Any, input2: Any): Any = ordering.lt(input1, input2) @@ -1254,8 +1253,8 @@ case class LessThan(left: Expression, right: Expression) since = "1.0.0", group = "predicate_funcs") case class LessThanOrEqual(left: Expression, right: Expression) - extends BinaryComparison with NullIntolerant { - + extends BinaryComparison { + override def nullIntolerant: Boolean = true override def symbol: String = "<=" protected override def nullSafeEval(input1: Any, input2: Any): Any = ordering.lteq(input1, input2) @@ -1289,7 +1288,8 @@ case class LessThanOrEqual(left: Expression, right: Expression) since = "1.0.0", group = "predicate_funcs") case class GreaterThan(left: Expression, right: Expression) - extends BinaryComparison with NullIntolerant { + extends BinaryComparison { + override def nullIntolerant: Boolean = true override def symbol: String = ">" @@ -1324,7 +1324,8 @@ case class GreaterThan(left: Expression, right: Expression) since = "1.0.0", group = "predicate_funcs") case class GreaterThanOrEqual(left: Expression, right: Expression) - extends BinaryComparison with NullIntolerant { + extends BinaryComparison { + override def nullIntolerant: Boolean = true override def symbol: String = ">=" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala index 4ace3dc95a43e..8f520fca43501 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala @@ -43,8 +43,8 @@ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String abstract class StringRegexExpression extends BinaryExpression - with ImplicitCastInputTypes with NullIntolerant with Predicate { - + with ImplicitCastInputTypes with Predicate { + override def nullIntolerant: Boolean = true def escape(v: String): String def matches(regex: Pattern, str: String): 
Boolean @@ -290,7 +290,8 @@ case class ILike( } sealed abstract class MultiLikeBase - extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant with Predicate { + extends UnaryExpression with ImplicitCastInputTypes with Predicate { + override def nullIntolerant: Boolean = true protected def patterns: Seq[UTF8String] @@ -566,8 +567,8 @@ case class RLike(left: Expression, right: Expression) extends StringRegexExpress since = "1.5.0", group = "string_funcs") case class StringSplit(str: Expression, regex: Expression, limit: Expression) - extends TernaryExpression with ImplicitCastInputTypes with NullIntolerant { - + extends TernaryExpression with ImplicitCastInputTypes { + override def nullIntolerant: Boolean = true override def dataType: DataType = ArrayType(str.dataType, containsNull = false) override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeBinaryLcase, StringTypeWithCollation, IntegerType) @@ -638,7 +639,8 @@ case class StringSplit(str: Expression, regex: Expression, limit: Expression) group = "string_funcs") // scalastyle:on line.size.limit case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expression, pos: Expression) - extends QuaternaryExpression with ImplicitCastInputTypes with NullIntolerant { + extends QuaternaryExpression with ImplicitCastInputTypes { + override def nullIntolerant: Boolean = true def this(subject: Expression, regexp: Expression, rep: Expression) = this(subject, regexp, rep, Literal(1)) @@ -805,7 +807,8 @@ object RegExpExtractBase { } abstract class RegExpExtractBase - extends TernaryExpression with ImplicitCastInputTypes with NullIntolerant { + extends TernaryExpression with ImplicitCastInputTypes { + override def nullIntolerant: Boolean = true def subject: Expression def regexp: Expression def idx: Expression diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 8e8d3a9574667..d92f45b1968ab 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -458,7 +458,8 @@ trait String2StringExpression extends ImplicitCastInputTypes { since = "1.0.1", group = "string_funcs") case class Upper(child: Expression) - extends UnaryExpression with String2StringExpression with NullIntolerant { + extends UnaryExpression with String2StringExpression { + override def nullIntolerant: Boolean = true final lazy val collationId: Int = child.dataType.asInstanceOf[StringType].collationId @@ -490,7 +491,8 @@ case class Upper(child: Expression) since = "1.0.1", group = "string_funcs") case class Lower(child: Expression) - extends UnaryExpression with String2StringExpression with NullIntolerant { + extends UnaryExpression with String2StringExpression { + override def nullIntolerant: Boolean = true final lazy val collationId: Int = child.dataType.asInstanceOf[StringType].collationId @@ -514,7 +516,8 @@ case class Lower(child: Expression) /** A base trait for functions that compare two strings, returning a boolean. 
*/ abstract class StringPredicate extends BinaryExpression - with Predicate with ImplicitCastInputTypes with NullIntolerant { + with Predicate with ImplicitCastInputTypes { + override def nullIntolerant: Boolean = true final lazy val collationId: Int = left.dataType.asInstanceOf[StringType].collationId @@ -745,10 +748,10 @@ case class EndsWith(left: Expression, right: Expression) extends StringPredicate since = "4.0.0", group = "string_funcs") case class IsValidUTF8(input: Expression) extends RuntimeReplaceable with ImplicitCastInputTypes - with UnaryLike[Expression] with NullIntolerant { + with UnaryLike[Expression] { + override def nullIntolerant: Boolean = true - override lazy val replacement: Expression = - Invoke.withNullIntolerant(input, "isValid", BooleanType) + override lazy val replacement: Expression = Invoke(input, "isValid", BooleanType) override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeWithCollation(supportsTrimCollation = true)) @@ -794,10 +797,10 @@ case class IsValidUTF8(input: Expression) extends RuntimeReplaceable with Implic group = "string_funcs") // scalastyle:on case class MakeValidUTF8(input: Expression) extends RuntimeReplaceable with ImplicitCastInputTypes - with UnaryLike[Expression] with NullIntolerant { + with UnaryLike[Expression] { + override def nullIntolerant: Boolean = true - override lazy val replacement: Expression = - Invoke.withNullIntolerant(input, "makeValid", input.dataType) + override lazy val replacement: Expression = Invoke(input, "makeValid", input.dataType) override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeWithCollation(supportsTrimCollation = true)) @@ -836,10 +839,11 @@ case class MakeValidUTF8(input: Expression) extends RuntimeReplaceable with Impl group = "string_funcs") // scalastyle:on case class ValidateUTF8(input: Expression) extends RuntimeReplaceable with ImplicitCastInputTypes - with UnaryLike[Expression] with NullIntolerant { + with UnaryLike[Expression] { + override def nullIntolerant: Boolean = true override lazy val replacement: Expression = - StaticInvoke.withNullIntolerant( + StaticInvoke( classOf[ExpressionImplUtils], input.dataType, "validateUTF8String", @@ -887,10 +891,11 @@ case class ValidateUTF8(input: Expression) extends RuntimeReplaceable with Impli group = "string_funcs") // scalastyle:on case class TryValidateUTF8(input: Expression) extends RuntimeReplaceable with ImplicitCastInputTypes - with UnaryLike[Expression] with NullIntolerant { + with UnaryLike[Expression] { + override def nullIntolerant: Boolean = true override lazy val replacement: Expression = - StaticInvoke.withNullIntolerant( + StaticInvoke( classOf[ExpressionImplUtils], input.dataType, "tryValidateUTF8String", @@ -934,7 +939,8 @@ case class TryValidateUTF8(input: Expression) extends RuntimeReplaceable with Im group = "string_funcs") // scalastyle:on line.size.limit case class StringReplace(srcExpr: Expression, searchExpr: Expression, replaceExpr: Expression) - extends TernaryExpression with ImplicitCastInputTypes with NullIntolerant { + extends TernaryExpression with ImplicitCastInputTypes { + override def nullIntolerant: Boolean = true final lazy val collationId: Int = first.dataType.asInstanceOf[StringType].collationId @@ -1028,7 +1034,8 @@ object Overlay { group = "string_funcs") // scalastyle:on line.size.limit case class Overlay(input: Expression, replace: Expression, pos: Expression, len: Expression) - extends QuaternaryExpression with ImplicitCastInputTypes with NullIntolerant { + extends QuaternaryExpression with 
ImplicitCastInputTypes { + override def nullIntolerant: Boolean = true def this(str: Expression, replace: Expression, pos: Expression) = { this(str, replace, pos, Literal.create(-1, IntegerType)) @@ -1167,7 +1174,8 @@ object StringTranslate { group = "string_funcs") // scalastyle:on line.size.limit case class StringTranslate(srcExpr: Expression, matchingExpr: Expression, replaceExpr: Expression) - extends TernaryExpression with ImplicitCastInputTypes with NullIntolerant { + extends TernaryExpression with ImplicitCastInputTypes { + override def nullIntolerant: Boolean = true @transient private var lastMatching: UTF8String = _ @transient private var lastReplace: UTF8String = _ @@ -1246,8 +1254,8 @@ case class StringTranslate(srcExpr: Expression, matchingExpr: Expression, replac group = "string_funcs") // scalastyle:on line.size.limit case class FindInSet(left: Expression, right: Expression) extends BinaryExpression - with ImplicitCastInputTypes with NullIntolerant { - + with ImplicitCastInputTypes { + override def nullIntolerant: Boolean = true final lazy val collationId: Int = left.dataType.asInstanceOf[StringType].collationId override def inputTypes: Seq[AbstractDataType] = @@ -1657,8 +1665,8 @@ case class StringTrimRight(srcStr: Expression, trimStr: Option[Expression] = Non group = "string_funcs") // scalastyle:on line.size.limit case class StringInstr(str: Expression, substr: Expression) - extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant { - + extends BinaryExpression with ImplicitCastInputTypes { + override def nullIntolerant: Boolean = true final lazy val collationId: Int = left.dataType.asInstanceOf[StringType].collationId override def left: Expression = str @@ -1710,8 +1718,8 @@ case class StringInstr(str: Expression, substr: Expression) group = "string_funcs") // scalastyle:on line.size.limit case class SubstringIndex(strExpr: Expression, delimExpr: Expression, countExpr: Expression) - extends TernaryExpression with ImplicitCastInputTypes with NullIntolerant { - + extends TernaryExpression with ImplicitCastInputTypes { + override def nullIntolerant: Boolean = true final lazy val collationId: Int = first.dataType.asInstanceOf[StringType].collationId override def dataType: DataType = strExpr.dataType @@ -1900,7 +1908,8 @@ object LPadExpressionBuilder extends PadExpressionBuilderBase { } case class StringLPad(str: Expression, len: Expression, pad: Expression) - extends TernaryExpression with ImplicitCastInputTypes with NullIntolerant { + extends TernaryExpression with ImplicitCastInputTypes { + override def nullIntolerant: Boolean = true override def first: Expression = str override def second: Expression = len @@ -1984,8 +1993,8 @@ object RPadExpressionBuilder extends PadExpressionBuilderBase { } case class StringRPad(str: Expression, len: Expression, pad: Expression = Literal(" ")) - extends TernaryExpression with ImplicitCastInputTypes with NullIntolerant { - + extends TernaryExpression with ImplicitCastInputTypes { + override def nullIntolerant: Boolean = true override def first: Expression = str override def second: Expression = len override def third: Expression = pad @@ -2146,8 +2155,8 @@ case class FormatString(children: Expression*) extends Expression with ImplicitC since = "1.5.0", group = "string_funcs") case class InitCap(child: Expression) - extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { - + extends UnaryExpression with ImplicitCastInputTypes { + override def nullIntolerant: Boolean = true final lazy val collationId: Int = 
child.dataType.asInstanceOf[StringType].collationId // Flag to indicate whether to use ICU instead of JVM case mappings for UTF8_BINARY collation. @@ -2181,8 +2190,8 @@ case class InitCap(child: Expression) since = "1.5.0", group = "string_funcs") case class StringRepeat(str: Expression, times: Expression) - extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant { - + extends BinaryExpression with ImplicitCastInputTypes { + override def nullIntolerant: Boolean = true override def left: Expression = str override def right: Expression = times override def dataType: DataType = str.dataType @@ -2219,8 +2228,8 @@ case class StringRepeat(str: Expression, times: Expression) since = "1.5.0", group = "string_funcs") case class StringSpace(child: Expression) - extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { - + extends UnaryExpression with ImplicitCastInputTypes { + override def nullIntolerant: Boolean = true override def dataType: DataType = SQLConf.get.defaultStringType override def inputTypes: Seq[DataType] = Seq(IntegerType) @@ -2274,7 +2283,8 @@ case class StringSpace(child: Expression) group = "string_funcs") // scalastyle:on line.size.limit case class Substring(str: Expression, pos: Expression, len: Expression) - extends TernaryExpression with ImplicitCastInputTypes with NullIntolerant { + extends TernaryExpression with ImplicitCastInputTypes { + override def nullIntolerant: Boolean = true def this(str: Expression, pos: Expression) = { this(str, pos, Literal(Integer.MAX_VALUE)) @@ -2418,7 +2428,8 @@ case class Left(str: Expression, len: Expression) extends RuntimeReplaceable group = "string_funcs") // scalastyle:on line.size.limit case class Length(child: Expression) - extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { + extends UnaryExpression with ImplicitCastInputTypes { + override def nullIntolerant: Boolean = true override def dataType: DataType = IntegerType override def inputTypes: Seq[AbstractDataType] = Seq( @@ -2458,7 +2469,8 @@ case class Length(child: Expression) since = "2.3.0", group = "string_funcs") case class BitLength(child: Expression) - extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { + extends UnaryExpression with ImplicitCastInputTypes { + override def nullIntolerant: Boolean = true override def dataType: DataType = IntegerType override def inputTypes: Seq[AbstractDataType] = Seq( @@ -2501,7 +2513,8 @@ case class BitLength(child: Expression) since = "2.3.0", group = "string_funcs") case class OctetLength(child: Expression) - extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { + extends UnaryExpression with ImplicitCastInputTypes { + override def nullIntolerant: Boolean = true override def dataType: DataType = IntegerType override def inputTypes: Seq[AbstractDataType] = Seq( @@ -2551,8 +2564,7 @@ case class Levenshtein( right: Expression, threshold: Option[Expression] = None) extends Expression - with ImplicitCastInputTypes - with NullIntolerant{ + with ImplicitCastInputTypes { def this(left: Expression, right: Expression, threshold: Expression) = this(left, right, Option(threshold)) @@ -2599,6 +2611,7 @@ case class Levenshtein( } override def nullable: Boolean = children.exists(_.nullable) + override def nullIntolerant: Boolean = true override def foldable: Boolean = children.forall(_.foldable) @@ -2700,7 +2713,8 @@ case class Levenshtein( since = "1.5.0", group = "string_funcs") case class SoundEx(child: Expression) - extends UnaryExpression with 
ExpectsInputTypes with NullIntolerant { + extends UnaryExpression with ExpectsInputTypes { + override def nullIntolerant: Boolean = true override def dataType: DataType = SQLConf.get.defaultStringType @@ -2732,8 +2746,8 @@ case class SoundEx(child: Expression) since = "1.5.0", group = "string_funcs") case class Ascii(child: Expression) - extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { - + extends UnaryExpression with ImplicitCastInputTypes { + override def nullIntolerant: Boolean = true override def dataType: DataType = IntegerType override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeWithCollation(supportsTrimCollation = true)) @@ -2780,7 +2794,8 @@ case class Ascii(child: Expression) group = "string_funcs") // scalastyle:on line.size.limit case class Chr(child: Expression) - extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { + extends UnaryExpression with ImplicitCastInputTypes { + override def nullIntolerant: Boolean = true override def dataType: DataType = SQLConf.get.defaultStringType override def inputTypes: Seq[DataType] = Seq(LongType) @@ -2878,7 +2893,8 @@ object Base64 { since = "1.5.0", group = "string_funcs") case class UnBase64(child: Expression, failOnError: Boolean = false) - extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { + extends UnaryExpression with ImplicitCastInputTypes { + override def nullIntolerant: Boolean = true override def dataType: DataType = BinaryType override def inputTypes: Seq[AbstractDataType] = @@ -3317,12 +3333,14 @@ case class ToBinary( since = "1.5.0", group = "string_funcs") case class FormatNumber(x: Expression, d: Expression) - extends BinaryExpression with ExpectsInputTypes with NullIntolerant { + extends BinaryExpression with ExpectsInputTypes { override def left: Expression = x override def right: Expression = d override def dataType: DataType = SQLConf.get.defaultStringType override def nullable: Boolean = true + override def nullIntolerant: Boolean = true + override def inputTypes: Seq[AbstractDataType] = Seq( NumericType, @@ -3566,11 +3584,12 @@ case class Sentences( */ case class StringSplitSQL( str: Expression, - delimiter: Expression) extends BinaryExpression with NullIntolerant { + delimiter: Expression) extends BinaryExpression { override def dataType: DataType = ArrayType(str.dataType, containsNull = false) final lazy val collationId: Int = left.dataType.asInstanceOf[StringType].collationId override def left: Expression = str override def right: Expression = delimiter + override def nullIntolerant: Boolean = true override def nullSafeEval(string: Any, delimiter: Any): Any = { val strings = CollationSupport.StringSplitSQL.exec(string.asInstanceOf[UTF8String], diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala index 67cdc0aa7a958..06aec93912984 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala @@ -136,9 +136,8 @@ case class IsVariantNull(child: Expression) extends UnaryExpression // scalastyle:on line.size.limit case class ToVariantObject(child: Expression) extends UnaryExpression - with NullIntolerant with QueryErrorsBase { - + override def nullIntolerant: Boolean = true override val dataType: DataType = VariantType 
// Only accept nested types at the root but any types can be nested inside. @@ -236,7 +235,6 @@ case class VariantGet( timeZoneId: Option[String] = None) extends BinaryExpression with TimeZoneAwareExpression - with NullIntolerant with ExpectsInputTypes with QueryErrorsBase { override def checkInputDataTypes(): TypeCheckResult = { @@ -277,6 +275,7 @@ case class VariantGet( override def prettyName: String = if (failOnError) "variant_get" else "try_variant_get" override def nullable: Boolean = true + override def nullIntolerant: Boolean = true protected override def nullSafeEval(input: Any, path: Any): Any = { VariantGet.variantGet( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/xpath.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/xpath.scala index c694067e06abf..9848e062a08fd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/xpath.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/xpath.scala @@ -34,9 +34,10 @@ import org.apache.spark.unsafe.types.UTF8String * This is not the world's most efficient implementation due to type conversion, but works. */ abstract class XPathExtract - extends BinaryExpression with ExpectsInputTypes with CodegenFallback with NullIntolerant { + extends BinaryExpression with ExpectsInputTypes with CodegenFallback { override def left: Expression = xml override def right: Expression = path + override def nullIntolerant: Boolean = true /** XPath expressions are always nullable, e.g. if the xml string is empty. */ override def nullable: Boolean = true diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xmlExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xmlExpressions.scala index f3f652b393f76..196c0793e6193 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xmlExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xmlExpressions.scala @@ -59,7 +59,6 @@ case class XmlToStructs( extends UnaryExpression with TimeZoneAwareExpression with ExpectsInputTypes - with NullIntolerant with QueryErrorsBase { def this(child: Expression, schema: Expression, options: Map[String, String]) = @@ -70,6 +69,7 @@ case class XmlToStructs( timeZoneId = None) override def nullable: Boolean = true + override def nullIntolerant: Boolean = true // The XML input data might be missing certain fields. We force the nullability // of the user-provided schema to avoid data corruptions. 
@@ -241,9 +241,9 @@ case class StructsToXml( timeZoneId: Option[String] = None) extends UnaryExpression with TimeZoneAwareExpression - with ExpectsInputTypes - with NullIntolerant { + with ExpectsInputTypes { override def nullable: Boolean = true + override def nullIntolerant: Boolean = true def this(options: Map[String, String], child: Expression) = this(options, child, None) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 195aa7bbeec02..3eb7eb6e6b2e8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -894,7 +894,7 @@ object NullPropagation extends Rule[LogicalPlan] { // Non-leaf NullIntolerant expressions will return null, if at least one of its children is // a null literal. - case e: NullIntolerant if e.children.exists(isNullLiteral) => + case e if e.nullIntolerant && e.children.exists(isNullLiteral) => Literal.create(null, e.dataType) } } @@ -914,7 +914,7 @@ object NullDownPropagation extends Rule[LogicalPlan] { // Applying to `EqualTo` is too disruptive for [SPARK-32290] optimization, not supported for now. // If e has multiple children, the deterministic check is required because optimizing // IsNull(a > b) to Or(IsNull(a), IsNull(b)), for example, may cause skipping the evaluation of b - private def supportedNullIntolerant(e: NullIntolerant): Boolean = (e match { + private def supportedNullIntolerant(e: Expression): Boolean = (e match { case _: Not => true case _: GreaterThan | _: GreaterThanOrEqual | _: LessThan | _: LessThanOrEqual if e.deterministic => true @@ -925,9 +925,9 @@ object NullDownPropagation extends Rule[LogicalPlan] { _.containsPattern(NULL_CHECK), ruleId) { case q: LogicalPlan => q.transformExpressionsDownWithPruning( _.containsPattern(NULL_CHECK), ruleId) { - case IsNull(e: NullIntolerant) if supportedNullIntolerant(e) => + case IsNull(e) if e.nullIntolerant && supportedNullIntolerant(e) => e.children.map(IsNull(_): Expression).reduceLeft(Or) - case IsNotNull(e: NullIntolerant) if supportedNullIntolerant(e) => + case IsNotNull(e) if e.nullIntolerant && supportedNullIntolerant(e) => e.children.map(IsNotNull(_): Expression).reduceLeft(And) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala index 5769f006ccbc3..ef035eba5922c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala @@ -137,7 +137,7 @@ trait ConstraintHelper { private def scanNullIntolerantAttribute(expr: Expression): Seq[Expression] = expr match { case e: ExtractValue if isExtractOnly(e) => Seq(e) case a: Attribute => Seq(a) - case _: NullIntolerant => expr.children.flatMap(scanNullIntolerantAttribute) + case e if e.nullIntolerant => expr.children.flatMap(scanNullIntolerantAttribute) case _ => Seq.empty[Attribute] } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala index 6eff610433c9c..a6fc43aa087da 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala @@ -288,15 +288,16 @@ class ExpressionInfoSuite extends SparkFunSuite with SharedSparkSession { candidateExprsToCheck.filter(superClass.isAssignableFrom).foreach { clazz => val isEvalOverrode = clazz.getMethod("eval", classOf[InternalRow]) != superClass.getMethod("eval", classOf[InternalRow]) - val isNullIntolerantMixedIn = classOf[NullIntolerant].isAssignableFrom(clazz) - if (isEvalOverrode && isNullIntolerantMixedIn) { - fail(s"${clazz.getName} should not extend ${classOf[NullIntolerant].getSimpleName}, " + + val isNullIntolerantOverridden = clazz.getMethod("nullIntolerant") != + classOf[Expression].getMethod("nullIntolerant") + if (isEvalOverrode && isNullIntolerantOverridden) { + fail(s"${clazz.getName} should not override nullIntolerant, " + s"or add ${clazz.getName} in the ignoreSet of this test.") - } else if (!isEvalOverrode && !isNullIntolerantMixedIn) { - fail(s"${clazz.getName} should extend ${classOf[NullIntolerant].getSimpleName}.") + } else if (!isEvalOverrode && !isNullIntolerantOverridden) { + fail(s"${clazz.getName} should override nullIntolerant.") } else { - assert((!isEvalOverrode && isNullIntolerantMixedIn) || - (isEvalOverrode && !isNullIntolerantMixedIn)) + assert((!isEvalOverrode && isNullIntolerantOverridden) || + (isEvalOverrode && !isNullIntolerantOverridden)) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala index 0550fae3805d4..8d3379805e013 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala @@ -1633,15 +1633,15 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel checkAnswer(df4, Seq(Row(6, "jen", 12000, 1200, true))) val df5 = sql("SELECT name FROM h2.test.employee WHERE " + - "aes_encrypt(cast(null as string), name) is null") + "aes_encrypt(name, '1234567812345678') != 'spark'") checkFiltersRemoved(df5, false) - val expectedPlanFragment5 = "PushedFilters: [], " + val expectedPlanFragment5 = "PushedFilters: [NAME IS NOT NULL], " checkPushedInfo(df5, expectedPlanFragment5) checkAnswer(df5, Seq(Row("amy"), Row("cathy"), Row("alex"), Row("david"), Row("jen"))) val df6 = sql("SELECT name FROM h2.test.employee WHERE " + "aes_decrypt(cast(null as binary), name) is null") - checkFiltersRemoved(df6, false) + checkFiltersRemoved(df6) // removed by null intolerant opt val expectedPlanFragment6 = "PushedFilters: [], " checkPushedInfo(df6, expectedPlanFragment6) checkAnswer(df6, Seq(Row("amy"), Row("cathy"), Row("alex"), Row("david"), Row("jen"))) From dc6fba58b9fc844205c626c0b629e1d054fece33 Mon Sep 17 00:00:00 2001 From: Vladimir Golubev Date: Wed, 13 Nov 2024 15:25:49 +0900 Subject: [PATCH 15/39] [SPARK-50290][SQL] Add a flag to disable DataFrameQueryContext creation ### What changes were proposed in this pull request? Add a new `spark.sql.dataFrameQueryContext.enabled` flag to disable the `DataFrameQueryContext` creation. ### Why are the changes needed? `DataFrameQueryContext` creation requires a stack trace. Stack trace collection has a non-trivial perf overhead, some users might want to disable this feature. 
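For illustration only (not part of this patch): a minimal spark-shell sketch of toggling the flag at runtime. Only the configuration key and its default come from this change; the session and the query are assumptions.

```scala
// spark-shell sketch, assuming a build that contains this patch.
// The flag defaults to true; setting it to false skips the stack-trace
// collection behind DataFrameQueryContext, trading error detail for lower overhead.
spark.conf.set("spark.sql.dataFrameQueryContext.enabled", "false")

// Subsequent DataFrame operations no longer record a DataFrame query context.
spark.range(10).selectExpr("id * 2").collect()

// Re-enable it when detailed error contexts are worth the overhead.
spark.conf.set("spark.sql.dataFrameQueryContext.enabled", "true")
```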
`spark.sql.dataFrameQueryContext.enabled` == `true`: ![image](https://github.com/user-attachments/assets/0fe16b9e-56b8-4ff5-82b5-1e5875c6d26b) `spark.sql.dataFrameQueryContext.enabled` == `false`: ![image](https://github.com/user-attachments/assets/5721cd26-5016-422b-bf06-fa63072bae2d) ### Does this PR introduce _any_ user-facing change? No, the default is still to collect the query context. ### How was this patch tested? Added a new test case. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48827 from vladimirg-db/vladimirg-db/do-not-collect-df-stacktrace-if-the-setting-is-zero. Authored-by: Vladimir Golubev Signed-off-by: Hyukjin Kwon --- .../spark/sql/catalyst/trees/origin.scala | 5 +++- .../spark/sql/internal/SqlApiConf.scala | 2 ++ .../apache/spark/sql/internal/SQLConf.scala | 11 ++++++++ .../spark/sql/errors/QueryContextSuite.scala | 25 ++++++++++++++----- 4 files changed, 36 insertions(+), 7 deletions(-) diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/origin.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/origin.scala index 33fa17433abbd..563554d506c4a 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/origin.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/origin.scala @@ -101,13 +101,16 @@ object CurrentOrigin { * invoke other APIs) only the first `withOrigin` is captured because that is closer to the user * code. * + * `withOrigin` has non-trivial performance overhead, since it collects a stack trace. This + * feature can be disabled by setting "spark.sql.dataFrameQueryContext.enabled" to "false". + * * @param f * The function that can use the origin. * @return * The result of `f`. */ private[sql] def withOrigin[T](f: => T): T = { - if (CurrentOrigin.get.stackTrace.isDefined) { + if (CurrentOrigin.get.stackTrace.isDefined || !SqlApiConf.get.dataFrameQueryContextEnabled) { f } else { val st = Thread.currentThread().getStackTrace diff --git a/sql/api/src/main/scala/org/apache/spark/sql/internal/SqlApiConf.scala b/sql/api/src/main/scala/org/apache/spark/sql/internal/SqlApiConf.scala index 9908021592e10..773494f418659 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/internal/SqlApiConf.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/internal/SqlApiConf.scala @@ -45,6 +45,7 @@ private[sql] trait SqlApiConf { def legacyTimeParserPolicy: LegacyBehaviorPolicy.Value def defaultStringType: StringType def stackTracesInDataFrameContext: Int + def dataFrameQueryContextEnabled: Boolean def legacyAllowUntypedScalaUDFs: Boolean def allowReadingUnknownCollations: Boolean } @@ -86,6 +87,7 @@ private[sql] object DefaultSqlApiConf extends SqlApiConf { override def legacyTimeParserPolicy: LegacyBehaviorPolicy.Value = LegacyBehaviorPolicy.CORRECTED override def defaultStringType: StringType = StringType override def stackTracesInDataFrameContext: Int = 1 + override def dataFrameQueryContextEnabled: Boolean = true override def legacyAllowUntypedScalaUDFs: Boolean = false override def allowReadingUnknownCollations: Boolean = false } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index eac89212b9da8..5218a683a8fa8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -5209,6 +5209,15 @@ object SQLConf { .checkValue(_ > 0, "The number of stack traces 
in the DataFrame context must be positive.") .createWithDefault(1) + val DATA_FRAME_QUERY_CONTEXT_ENABLED = buildConf("spark.sql.dataFrameQueryContext.enabled") + .internal() + .doc( + "Enable the DataFrame query context. This feature is enabled by default, but has a " + + "non-trivial performance overhead because of the stack trace collection.") + .version("4.0.0") + .booleanConf + .createWithDefault(true) + val LEGACY_JAVA_CHARSETS = buildConf("spark.sql.legacy.javaCharsets") .internal() .doc("When set to true, the functions like `encode()` can use charsets from JDK while " + @@ -6232,6 +6241,8 @@ class SQLConf extends Serializable with Logging with SqlApiConf { override def stackTracesInDataFrameContext: Int = getConf(SQLConf.STACK_TRACES_IN_DATAFRAME_CONTEXT) + def dataFrameQueryContextEnabled: Boolean = getConf(SQLConf.DATA_FRAME_QUERY_CONTEXT_ENABLED) + override def legacyAllowUntypedScalaUDFs: Boolean = getConf(SQLConf.LEGACY_ALLOW_UNTYPED_SCALA_UDF) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryContextSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryContextSuite.scala index 426822da3c912..693ebdc413c43 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryContextSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryContextSuite.scala @@ -16,27 +16,40 @@ */ package org.apache.spark.sql.errors -import org.apache.spark.SparkArithmeticException +import org.apache.spark.{SparkArithmeticException, SparkConf} import org.apache.spark.sql.QueryTest -import org.apache.spark.sql.functions.lit +import org.apache.spark.sql.functions.{col, lit} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession class QueryContextSuite extends QueryTest with SharedSparkSession { + override def sparkConf: SparkConf = super.sparkConf.set(SQLConf.ANSI_ENABLED.key, "true") + + private val ansiConf = "\"" + SQLConf.ANSI_ENABLED.key + "\"" test("summary of DataFrame context") { - withSQLConf( - SQLConf.ANSI_ENABLED.key -> "true", - SQLConf.STACK_TRACES_IN_DATAFRAME_CONTEXT.key -> "2") { + withSQLConf(SQLConf.STACK_TRACES_IN_DATAFRAME_CONTEXT.key -> "2") { val e = intercept[SparkArithmeticException] { spark.range(1).select(lit(1) / lit(0)).collect() } assert(e.getQueryContext.head.summary() == """== DataFrame == |"div" was called from - |org.apache.spark.sql.errors.QueryContextSuite.$anonfun$new$3(QueryContextSuite.scala:32) + |org.apache.spark.sql.errors.QueryContextSuite.$anonfun$new$3(QueryContextSuite.scala:33) |org.scalatest.Assertions.intercept(Assertions.scala:749) |""".stripMargin) } } + + test("SPARK-50290: Add a flag to disable DataFrame context") { + withSQLConf(SQLConf.DATA_FRAME_QUERY_CONTEXT_ENABLED.key -> "false") { + val df = spark.range(1).select(lit(1) / col("id")) + checkError( + exception = intercept[SparkArithmeticException](df.collect()), + condition = "DIVIDE_BY_ZERO", + parameters = Map("config" -> ansiConf), + context = ExpectedContext("", -1, -1) + ) + } + } } From 158aeb071147fac7d2c8400e37b07d21de51819e Mon Sep 17 00:00:00 2001 From: Hyukjin Kwon Date: Wed, 13 Nov 2024 15:42:32 +0900 Subject: [PATCH 16/39] [SPARK-38912][PYTHON] Remove the comment related to classmethod and property ### What changes were proposed in this pull request? This PR proposes to remove the comment related to `classmethod` and `property`. ### Why are the changes needed? 
This is deprecated: https://docs.python.org/3.13/library/functions.html#classmethod ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? CI in this PR should verify it. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48832 from HyukjinKwon/SPARK-38912. Authored-by: Hyukjin Kwon Signed-off-by: Hyukjin Kwon --- python/pyspark/sql/session.py | 18 ------------------ 1 file changed, 18 deletions(-) diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index 748dd2cafa7c3..4979ce712673e 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -139,15 +139,6 @@ def toDF(self, schema=None, sampleRatio=None): RDD.toDF = toDF # type: ignore[method-assign] -# TODO(SPARK-38912): This method can be dropped once support for Python 3.8 is dropped -# In Python 3.9, the @property decorator has been made compatible with the -# @classmethod decorator (https://docs.python.org/3.9/library/functions.html#classmethod) -# -# @classmethod + @property is also affected by a bug in Python's docstring which was backported -# to Python 3.9.6 (https://github.com/python/cpython/pull/28838) -# -# Python 3.9 with MyPy complains about @classmethod + @property combination. We should fix -# it together with MyPy. class classproperty(property): """Same as Python's @property decorator, but for class attributes. @@ -597,15 +588,6 @@ def create(self) -> "SparkSession": messageParameters={"feature": "SparkSession.builder.create"}, ) - # TODO(SPARK-38912): Replace classproperty with @classmethod + @property once support for - # Python 3.8 is dropped. - # - # In Python 3.9, the @property decorator has been made compatible with the - # @classmethod decorator (https://docs.python.org/3.9/library/functions.html#classmethod) - # - # @classmethod + @property is also affected by a bug in Python's docstring which was backported - # to Python 3.9.6 (https://github.com/python/cpython/pull/28838) - # # SPARK-47544: Explicitly declaring this as an identifier instead of a method. # If changing, make sure this bug is not reintroduced. builder: Builder = classproperty(lambda cls: cls.Builder()) # type: ignore From e29db6ecfa57d80e947f3d6fbd383ea29af27a43 Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Wed, 13 Nov 2024 16:27:25 +0900 Subject: [PATCH 17/39] [SPARK-50283][INFRA] Add a separate docker file for linter ### What changes were proposed in this pull request? Add a separate docker file for linter ### Why are the changes needed? 1, to centralize the installation of linter; 2, to spin it off the single docker; ### Does this PR introduce _any_ user-facing change? no, infra-only ### How was this patch tested? ci ### Was this patch authored or co-authored using generative AI tooling? no Closes #48826 from zhengruifeng/infra_separate_docker_lint. 
Authored-by: Ruifeng Zheng Signed-off-by: Hyukjin Kwon --- .github/workflows/build_and_test.yml | 39 +++++--- .../workflows/build_infra_images_cache.yml | 14 +++ dev/spark-test-image/lint/Dockerfile | 96 +++++++++++++++++++ python/pyspark/sql/connect/expressions.py | 2 +- python/pyspark/sql/pandas/conversion.py | 2 +- 5 files changed, 138 insertions(+), 15 deletions(-) create mode 100644 dev/spark-test-image/lint/Dockerfile diff --git a/.github/workflows/build_and_test.yml b/.github/workflows/build_and_test.yml index 696bbf9cfe41c..fc0959c5a415a 100644 --- a/.github/workflows/build_and_test.yml +++ b/.github/workflows/build_and_test.yml @@ -60,6 +60,8 @@ jobs: image_url: ${{ steps.infra-image-outputs.outputs.image_url }} image_docs_url: ${{ steps.infra-image-docs-outputs.outputs.image_docs_url }} image_docs_url_link: ${{ steps.infra-image-link.outputs.image_docs_url_link }} + image_lint_url: ${{ steps.infra-image-lint-outputs.outputs.image_lint_url }} + image_lint_url_link: ${{ steps.infra-image-link.outputs.image_lint_url_link }} steps: - name: Checkout Spark repository uses: actions/checkout@v4 @@ -144,6 +146,14 @@ jobs: IMG_NAME="apache-spark-ci-image-docs:${{ inputs.branch }}-${{ github.run_id }}" IMG_URL="ghcr.io/$REPO_OWNER/$IMG_NAME" echo "image_docs_url=$IMG_URL" >> $GITHUB_OUTPUT + - name: Generate infra image URL (Linter) + id: infra-image-lint-outputs + run: | + # Convert to lowercase to meet Docker repo name requirement + REPO_OWNER=$(echo "${{ github.repository_owner }}" | tr '[:upper:]' '[:lower:]') + IMG_NAME="apache-spark-ci-image-lint:${{ inputs.branch }}-${{ github.run_id }}" + IMG_URL="ghcr.io/$REPO_OWNER/$IMG_NAME" + echo "image_lint_url=$IMG_URL" >> $GITHUB_OUTPUT - name: Link the docker images id: infra-image-link run: | @@ -151,8 +161,10 @@ jobs: # Should delete the link and directly use image_docs_url after SPARK 3.x EOL if [[ "${{ inputs.branch }}" == 'branch-3.5' ]]; then echo "image_docs_url_link=${{ steps.infra-image-outputs.outputs.image_url }}" >> $GITHUB_OUTPUT + echo "image_lint_url_link=${{ steps.infra-image-outputs.outputs.image_url }}" >> $GITHUB_OUTPUT else echo "image_docs_url_link=${{ steps.infra-image-docs-outputs.outputs.image_docs_url }}" >> $GITHUB_OUTPUT + echo "image_lint_url_link=${{ steps.infra-image-lint-outputs.outputs.image_lint_url }}" >> $GITHUB_OUTPUT fi # Build: build Spark and run the tests for specified modules. @@ -382,6 +394,17 @@ jobs: ${{ needs.precondition.outputs.image_docs_url }} # Use the infra image cache to speed up cache-from: type=registry,ref=ghcr.io/apache/spark/apache-spark-github-action-image-docs-cache:${{ inputs.branch }} + - name: Build and push (Linter) + if: hashFiles('dev/spark-test-image/lint/Dockerfile') != '' + id: docker_build_lint + uses: docker/build-push-action@v6 + with: + context: ./dev/spark-test-image/lint/ + push: true + tags: | + ${{ needs.precondition.outputs.image_lint_url }} + # Use the infra image cache to speed up + cache-from: type=registry,ref=ghcr.io/apache/spark/apache-spark-github-action-image-lint-cache:${{ inputs.branch }} pyspark: @@ -667,7 +690,7 @@ jobs: PYSPARK_PYTHON: python3.9 GITHUB_PREV_SHA: ${{ github.event.before }} container: - image: ${{ needs.precondition.outputs.image_url }} + image: ${{ needs.precondition.outputs.image_lint_url_link }} steps: - name: Checkout Spark repository uses: actions/checkout@v4 @@ -741,18 +764,8 @@ jobs: # Should delete this section after SPARK 3.5 EOL. 
python3.9 -m pip install 'flake8==3.9.0' pydata_sphinx_theme 'mypy==0.982' 'pytest==7.1.3' 'pytest-mypy-plugins==1.9.3' numpydoc 'jinja2<3.0.0' 'black==22.6.0' python3.9 -m pip install 'pandas-stubs==1.2.0.53' ipython 'grpcio==1.56.0' 'grpc-stubs==1.24.11' 'googleapis-common-protos-stubs==2.2.0' - - name: Install Python dependencies for python linter and documentation generation - if: inputs.branch != 'branch-3.5' - run: | - # Should unpin 'sphinxcontrib-*' after upgrading sphinx>5 - # See 'ipython_genutils' in SPARK-38517 - # See 'docutils<0.18.0' in SPARK-39421 - python3.9 -m pip install 'sphinx==4.5.0' mkdocs 'pydata_sphinx_theme>=0.13' sphinx-copybutton nbsphinx numpydoc jinja2 markupsafe 'pyzmq<24.0.0' \ - ipython ipython_genutils sphinx_plotly_directive numpy pyarrow pandas 'plotly>=4.8' 'docutils<0.18.0' \ - 'flake8==3.9.0' 'mypy==1.8.0' 'pytest==7.1.3' 'pytest-mypy-plugins==1.9.3' 'black==23.9.1' \ - 'pandas-stubs==1.2.0.53' 'grpcio==1.67.0' 'grpc-stubs==1.24.11' 'googleapis-common-protos-stubs==2.2.0' \ - 'sphinxcontrib-applehelp==1.0.4' 'sphinxcontrib-devhelp==1.0.2' 'sphinxcontrib-htmlhelp==2.0.1' 'sphinxcontrib-qthelp==1.0.3' 'sphinxcontrib-serializinghtml==1.1.5' - python3.9 -m pip list + - name: List Python packages + run: python3.9 -m pip list - name: Python linter run: PYTHON_EXECUTABLE=python3.9 ./dev/lint-python # Should delete this section after SPARK 3.5 EOL. diff --git a/.github/workflows/build_infra_images_cache.yml b/.github/workflows/build_infra_images_cache.yml index 3d7b3cc71f25b..b82d0633b0cee 100644 --- a/.github/workflows/build_infra_images_cache.yml +++ b/.github/workflows/build_infra_images_cache.yml @@ -28,6 +28,7 @@ on: paths: - 'dev/infra/Dockerfile' - 'dev/spark-test-image/docs/Dockerfile' + - 'dev/spark-test-image/lint/Dockerfile' - '.github/workflows/build_infra_images_cache.yml' # Create infra image when cutting down branches/tags create: @@ -74,3 +75,16 @@ jobs: - name: Image digest (Documentation) if: hashFiles('dev/spark-test-image/docs/Dockerfile') != '' run: echo ${{ steps.docker_build_docs.outputs.digest }} + - name: Build and push (Linter) + if: hashFiles('dev/spark-test-image/lint/Dockerfile') != '' + id: docker_build_lint + uses: docker/build-push-action@v6 + with: + context: ./dev/spark-test-image/lint/ + push: true + tags: ghcr.io/apache/spark/apache-spark-github-action-image-lint-cache:${{ github.ref_name }}-static + cache-from: type=registry,ref=ghcr.io/apache/spark/apache-spark-github-action-image-lint-cache:${{ github.ref_name }} + cache-to: type=registry,ref=ghcr.io/apache/spark/apache-spark-github-action-image-lint-cache:${{ github.ref_name }},mode=max + - name: Image digest (Linter) + if: hashFiles('dev/spark-test-image/lint/Dockerfile') != '' + run: echo ${{ steps.docker_build_lint.outputs.digest }} diff --git a/dev/spark-test-image/lint/Dockerfile b/dev/spark-test-image/lint/Dockerfile new file mode 100644 index 0000000000000..f9ea3124291b1 --- /dev/null +++ b/dev/spark-test-image/lint/Dockerfile @@ -0,0 +1,96 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Image for building and testing Spark branches. Based on Ubuntu 22.04. +# See also in https://hub.docker.com/_/ubuntu +FROM ubuntu:jammy-20240911.1 +LABEL org.opencontainers.image.authors="Apache Spark project " +LABEL org.opencontainers.image.licenses="Apache-2.0" +LABEL org.opencontainers.image.ref.name="Apache Spark Infra Image for Linter" +# Overwrite this label to avoid exposing the underlying Ubuntu OS version label +LABEL org.opencontainers.image.version="" + +ENV FULL_REFRESH_DATE 20241112 + +ENV DEBIAN_FRONTEND noninteractive +ENV DEBCONF_NONINTERACTIVE_SEEN true + +RUN apt-get update && apt-get install -y \ + build-essential \ + ca-certificates \ + curl \ + gfortran \ + git \ + gnupg \ + libcurl4-openssl-dev \ + libfontconfig1-dev \ + libfreetype6-dev \ + libfribidi-dev \ + libgit2-dev \ + libharfbuzz-dev \ + libjpeg-dev \ + libpng-dev \ + libssl-dev \ + libtiff5-dev \ + libxml2-dev \ + nodejs \ + npm \ + pkg-config \ + qpdf \ + r-base \ + software-properties-common \ + wget \ + zlib1g-dev \ + && rm -rf /var/lib/apt/lists/* + +RUN Rscript -e "install.packages(c('devtools', 'knitr', 'markdown', 'rmarkdown', 'testthat'), repos='https://cloud.r-project.org/')" \ + && Rscript -e "devtools::install_version('pkgdown', version='2.0.1', repos='https://cloud.r-project.org')" \ + && Rscript -e "devtools::install_version('preferably', version='0.4', repos='https://cloud.r-project.org')" \ + && Rscript -e "devtools::install_version('lintr', version='2.0.1', repos='https://cloud.r-project.org')" \ + +# See more in SPARK-39735 +ENV R_LIBS_SITE "/usr/local/lib/R/site-library:${R_LIBS_SITE}:/usr/lib/R/library" + +# Install Python 3.9 +RUN add-apt-repository ppa:deadsnakes/ppa +RUN apt-get update && apt-get install -y python3.9 python3.9-distutils \ + && rm -rf /var/lib/apt/lists/* +RUN curl -sS https://bootstrap.pypa.io/get-pip.py | python3.9 + +RUN python3.9 -m pip install \ + 'black==23.9.1' \ + 'flake8==3.9.0' \ + 'googleapis-common-protos-stubs==2.2.0' \ + 'grpc-stubs==1.24.11' \ + 'grpcio-status==1.67.0' \ + 'grpcio==1.67.0' \ + 'ipython' \ + 'ipython_genutils' \ + 'jinja2' \ + 'matplotlib' \ + 'mypy==1.8.0' \ + 'numpy==2.0.2' \ + 'numpydoc' \ + 'pandas' \ + 'pandas-stubs==1.2.0.53' \ + 'plotly>=4.8' \ + 'pyarrow>=18.0.0' \ + 'pytest-mypy-plugins==1.9.3' \ + 'pytest==7.1.3' \ + && python3.9 -m pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu \ + && python3.9 -m pip install torcheval \ + && python3.9 -m pip cache purge diff --git a/python/pyspark/sql/connect/expressions.py b/python/pyspark/sql/connect/expressions.py index 4915078af0225..5a5320366f666 100644 --- a/python/pyspark/sql/connect/expressions.py +++ b/python/pyspark/sql/connect/expressions.py @@ -490,7 +490,7 @@ def __repr__(self) -> str: # is sightly different: # java.time.Duration only applies HOURS, MINUTES, SECONDS units, # while Pandas applies all supported units. 
- return pd.Timedelta(delta).isoformat() # type: ignore[attr-defined] + return pd.Timedelta(delta).isoformat() return f"{self._value}" diff --git a/python/pyspark/sql/pandas/conversion.py b/python/pyspark/sql/pandas/conversion.py index 6f305b5b73b6c..172a4fc4b2343 100644 --- a/python/pyspark/sql/pandas/conversion.py +++ b/python/pyspark/sql/pandas/conversion.py @@ -520,7 +520,7 @@ def convert_timestamp(value: Any) -> Any: else: return ( pd.Timestamp(value) - .tz_localize(timezone, ambiguous=False) # type: ignore + .tz_localize(timezone, ambiguous=False) .tz_convert(_get_local_timezone()) .tz_localize(None) .to_pydatetime() From cc84e7f1824f4712a3da7167f4284cabc004593e Mon Sep 17 00:00:00 2001 From: Hyukjin Kwon Date: Wed, 13 Nov 2024 16:37:32 +0900 Subject: [PATCH 18/39] [SPARK-50296][PYTHON][CONNECT] Avoid using a classproperty in threadpool for Python Connect client ### What changes were proposed in this pull request? This PR proposes to avoid using `classmethod` and `property` combination in Python Connect client ### Why are the changes needed? In order to fix up the test failure at https://github.com/apache/spark/actions/runs/11804766326/job/32885813371. https://docs.python.org/3.13/library/functions.html#classmethod Seems like this combination was deprecated in Python 3.11, and removed in 3.13. ### Does this PR introduce _any_ user-facing change? Yes, it properly supports Python 3.13 with Python Connect client. ### How was this patch tested? Manually tested locally. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48831 from HyukjinKwon/SPARK-50296. Authored-by: Hyukjin Kwon Signed-off-by: Hyukjin Kwon --- python/pyspark/sql/connect/client/reattach.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/python/pyspark/sql/connect/client/reattach.py b/python/pyspark/sql/connect/client/reattach.py index e6dba6e0073f7..91b7aa125920d 100644 --- a/python/pyspark/sql/connect/client/reattach.py +++ b/python/pyspark/sql/connect/client/reattach.py @@ -60,9 +60,8 @@ class ExecutePlanResponseReattachableIterator(Generator): _lock: ClassVar[RLock] = RLock() _release_thread_pool_instance: Optional[ThreadPoolExecutor] = None - @classmethod # type: ignore[misc] - @property - def _release_thread_pool(cls) -> ThreadPoolExecutor: + @classmethod + def _get_or_create_release_thread_pool(cls) -> ThreadPoolExecutor: # Perform a first check outside the critical path. if cls._release_thread_pool_instance is not None: return cls._release_thread_pool_instance @@ -80,7 +79,7 @@ def shutdown(cls: Type["ExecutePlanResponseReattachableIterator"]) -> None: """ with cls._lock: if cls._release_thread_pool_instance is not None: - cls._release_thread_pool.shutdown() # type: ignore[attr-defined] + cls._get_or_create_release_thread_pool().shutdown() cls._release_thread_pool_instance = None def __init__( @@ -130,6 +129,10 @@ def __init__( # Current item from this iterator. self._current: Optional[pb2.ExecutePlanResponse] = None + @property + def _release_thread_pool(self) -> ThreadPoolExecutor: + return self._get_or_create_release_thread_pool() + def send(self, value: Any) -> pb2.ExecutePlanResponse: # will trigger reattach in case the stream completed without result_complete if not self._has_next(): From 82040bbd41cc236ea2c41b5a654a4f59b7a19a3c Mon Sep 17 00:00:00 2001 From: Chenghao Lyu Date: Wed, 13 Nov 2024 16:00:54 +0800 Subject: [PATCH 19/39] [SPARK-49563][SQL] Add SQL pipe syntax for the WINDOW operator ### What changes were proposed in this pull request? 
This PR adds SQL pipe syntax support for the WINDOW clause within the pipe SELECT operator `|> SELECT`. For example ```sparksql CREATE TEMPORARY VIEW t AS SELECT * from VALUES (1, 'apple', 1), (2, 'banana', 2), (3, 'apple', 3), (4, 'banana', 4), AS t(id, name, amount); TABLE t |> SELECT id, name, amount, SUM(amount) OVER w WINDOW w AS (PARTITION BY name ORDER BY id); 1, apple, 1, 3 3, apple, 2, 3 2, banana, 3, 7 4, banana, 4, 7 ``` Notes: 1. `|> WHERE` is not extended to use the WINDOW clause because `|> WHERE` does not support window functions in its expressions. ### Why are the changes needed? The SQL pipe operator syntax will let users compose queries in a more flexible fashion. ### Does this PR introduce _any_ user-facing change? Yes, see above. ### How was this patch tested? Yes ### Was this patch authored or co-authored using generative AI tooling? No Closes #48649 from Angryrou/pipe-window. Authored-by: Chenghao Lyu Signed-off-by: Wenchen Fan --- .../resources/error/error-conditions.json | 13 + .../sql/catalyst/parser/SqlBaseParser.g4 | 7 +- .../spark/sql/errors/QueryParsingErrors.scala | 4 + .../sql/catalyst/analysis/Analyzer.scala | 18 +- .../sql/catalyst/parser/AstBuilder.scala | 23 +- .../plans/logical/basicLogicalOperators.scala | 3 +- .../sql/catalyst/parser/PlanParserSuite.scala | 7 +- .../analyzer-results/pipe-operators.sql.out | 294 ++++++++++++++--- .../sql-tests/inputs/pipe-operators.sql | 97 +++++- .../sql-tests/results/pipe-operators.sql.out | 304 +++++++++++++++--- .../sql/execution/SparkSqlParserSuite.scala | 12 +- 11 files changed, 656 insertions(+), 126 deletions(-) diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json index a8e60d1850e2e..cc31678bc1ec9 100644 --- a/common/utils/src/main/resources/error/error-conditions.json +++ b/common/utils/src/main/resources/error/error-conditions.json @@ -3638,6 +3638,19 @@ }, "sqlState" : "42601" }, + "NOT_ALLOWED_IN_PIPE_OPERATOR_WHERE" : { + "message" : [ + "Not allowed in the pipe WHERE clause:" + ], + "subClass" : { + "WINDOW_CLAUSE" : { + "message" : [ + "WINDOW clause." + ] + } + }, + "sqlState" : "42601" + }, "NOT_A_CONSTANT_STRING" : { "message" : [ "The expression used for the routine or clause must be a constant STRING which is NOT NULL." diff --git a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 index 4900c971966cc..0d049f4b18e0e 100644 --- a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 +++ b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 @@ -1502,8 +1502,11 @@ version ; operatorPipeRightSide - : selectClause - | whereClause + : selectClause windowClause? + // Note that the WINDOW clause is not allowed in the WHERE pipe operator, but we add it here in + // the grammar simply for purposes of catching this invalid syntax and throwing a specific + // dedicated error message. + | whereClause windowClause? // The following two cases match the PIVOT or UNPIVOT clause, respectively. // For each one, we add the other clause as an option in order to return high-quality error // messages in the event that both are present (this is not allowed). 
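For illustration only (not part of the patch): a spark-shell sketch of the syntax this grammar change enables, assuming a build that includes it and that pipe operator syntax is enabled there; the view name and data are made up.

```scala
// Assumes the `spark` session provided by spark-shell on a build with this patch.
spark.sql(
  "CREATE OR REPLACE TEMPORARY VIEW sales AS " +
    "SELECT * FROM VALUES (1, 'apple', 1), (2, 'banana', 2), (3, 'apple', 3), (4, 'banana', 4) " +
    "AS sales(id, name, amount)")

// WINDOW attaches to the pipe SELECT operator; because the window spec has an
// ORDER BY, the default frame yields a running total within each partition.
spark.sql(
  "TABLE sales " +
    "|> SELECT id, name, amount, SUM(amount) OVER w AS running_total " +
    "   WINDOW w AS (PARTITION BY name ORDER BY id)").show()

// By contrast, a WINDOW clause after |> WHERE is rejected with the new
// NOT_ALLOWED_IN_PIPE_OPERATOR_WHERE.WINDOW_CLAUSE error, per the grammar rule above.
```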
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala b/sql/api/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala index 89ca45fa51256..0bd9f38014984 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala @@ -114,6 +114,10 @@ private[sql] object QueryParsingErrors extends DataTypeErrorsBase { ctx) } + def windowClauseInPipeOperatorWhereClauseNotAllowedError(ctx: ParserRuleContext): Throwable = { + new ParseException(errorClass = "NOT_ALLOWED_IN_PIPE_OPERATOR_WHERE.WINDOW_CLAUSE", ctx) + } + def distributeByUnsupportedError(ctx: QueryOrganizationContext): Throwable = { new ParseException(errorClass = "_LEGACY_ERROR_TEMP_0012", ctx) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 2e82d7ad39c45..d1d04d4117263 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -425,12 +425,18 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUpWithPruning( _.containsAnyPattern(WITH_WINDOW_DEFINITION, UNRESOLVED_WINDOW_EXPRESSION), ruleId) { // Lookup WindowSpecDefinitions. This rule works with unresolved children. - case WithWindowDefinition(windowDefinitions, child) => child.resolveExpressions { - case UnresolvedWindowExpression(c, WindowSpecReference(windowName)) => - val windowSpecDefinition = windowDefinitions.getOrElse(windowName, - throw QueryCompilationErrors.windowSpecificationNotDefinedError(windowName)) - WindowExpression(c, windowSpecDefinition) - } + case WithWindowDefinition(windowDefinitions, child, forPipeSQL) => + val resolveWindowExpression: PartialFunction[Expression, Expression] = { + case UnresolvedWindowExpression(c, WindowSpecReference(windowName)) => + val windowSpecDefinition = windowDefinitions.getOrElse(windowName, + throw QueryCompilationErrors.windowSpecificationNotDefinedError(windowName)) + WindowExpression(c, windowSpecDefinition) + } + if (forPipeSQL) { + child.transformExpressions(resolveWindowExpression) + } else { + child.resolveExpressions(resolveWindowExpression) + } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 044e945d16ad1..0c03ce9bed118 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -979,9 +979,10 @@ class AstBuilder extends DataTypeAstBuilder * Add ORDER BY/SORT BY/CLUSTER BY/DISTRIBUTE BY/LIMIT/WINDOWS clauses to the logical plan. These * clauses determine the shape (ordering/partitioning/rows) of the query result. * - * If 'forPipeOperators' is true, throws an error if the WINDOW clause is present (since this is - * not currently supported) or if more than one clause is present (this can be useful when parsing - * clauses used with pipe operations which only allow one instance of these clauses each). 
+ * If 'forPipeOperators' is true, throws an error if the WINDOW clause is present (since it breaks + * the composability of the pipe operators) or if more than one clause is present (this can be + * useful when parsing clauses used with pipe operations which only allow one instance of these + * clauses each). */ private def withQueryResultClauses( ctx: QueryOrganizationContext, @@ -1023,7 +1024,7 @@ class AstBuilder extends DataTypeAstBuilder // WINDOWS val withWindow = withOrder.optionalMap(windowClause) { - withWindowClause + withWindowClause(_, _, forPipeOperators) } if (forPipeOperators && windowClause != null) { throw QueryParsingErrors.clausesWithPipeOperatorsUnsupportedError( @@ -1306,7 +1307,9 @@ class AstBuilder extends DataTypeAstBuilder } // Window - val withWindow = withDistinct.optionalMap(windowClause)(withWindowClause) + val withWindow = withDistinct.optionalMap(windowClause) { + withWindowClause(_, _, isPipeOperatorSelect) + } withWindow } @@ -1463,7 +1466,8 @@ class AstBuilder extends DataTypeAstBuilder */ private def withWindowClause( ctx: WindowClauseContext, - query: LogicalPlan): LogicalPlan = withOrigin(ctx) { + query: LogicalPlan, + forPipeSQL: Boolean): LogicalPlan = withOrigin(ctx) { // Collect all window specifications defined in the WINDOW clause. val baseWindowTuples = ctx.namedWindow.asScala.map { wCtx => @@ -1495,7 +1499,7 @@ class AstBuilder extends DataTypeAstBuilder // Note that mapValues creates a view instead of materialized map. We force materialization by // mapping over identity. - WithWindowDefinition(windowMapView.map(identity), query) + WithWindowDefinition(windowMapView.map(identity), query, forPipeSQL) } /** @@ -5894,10 +5898,13 @@ class AstBuilder extends DataTypeAstBuilder whereClause = null, aggregationClause = null, havingClause = null, - windowClause = null, + windowClause = ctx.windowClause, relation = left, isPipeOperatorSelect = true) }.getOrElse(Option(ctx.whereClause).map { c => + if (ctx.windowClause() != null) { + throw QueryParsingErrors.windowClauseInPipeOperatorWhereClauseNotAllowedError(ctx) + } withWhereClause(c, withSubqueryAlias()) }.getOrElse(Option(ctx.pivotClause()).map { c => if (ctx.unpivotClause() != null) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 15f52e856bef0..dc286183ac689 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -929,7 +929,8 @@ trait CTEInChildren extends LogicalPlan { case class WithWindowDefinition( windowDefinitions: Map[String, WindowSpecDefinition], - child: LogicalPlan) extends UnaryNode { + child: LogicalPlan, + forPipeSQL: Boolean) extends UnaryNode { override def output: Seq[Attribute] = child.output final override val nodePatterns: Seq[TreePattern] = Seq(WITH_WINDOW_DEFINITION) override protected def withNewChildInternal(newChild: LogicalPlan): WithWindowDefinition = diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala index fdae3863f46fb..c556a92373954 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala +++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala @@ -372,8 +372,9 @@ class PlanParserSuite extends AnalysisTest { val limitWindowClauses = Seq( ("", (p: LogicalPlan) => p), (" limit 10", (p: LogicalPlan) => p.limit(10)), - (" window w1 as ()", (p: LogicalPlan) => WithWindowDefinition(ws, p)), - (" window w1 as () limit 10", (p: LogicalPlan) => WithWindowDefinition(ws, p).limit(10)) + (" window w1 as ()", (p: LogicalPlan) => WithWindowDefinition(ws, p, forPipeSQL = false)), + (" window w1 as () limit 10", (p: LogicalPlan) => + WithWindowDefinition(ws, p, forPipeSQL = false).limit(10)) ) val orderSortDistrClusterClauses = Seq( @@ -524,7 +525,7 @@ class PlanParserSuite extends AnalysisTest { |window w1 as (partition by a, b order by c rows between 1 preceding and 1 following), | w2 as w1, | w3 as w1""".stripMargin, - WithWindowDefinition(ws1, plan)) + WithWindowDefinition(ws1, plan, forPipeSQL = false)) } test("lateral view") { diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/pipe-operators.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/pipe-operators.sql.out index 6af64b116f7d0..47eb8f2417381 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/pipe-operators.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/pipe-operators.sql.out @@ -704,7 +704,7 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException -- !query table t -|> where first_value(x) over (partition by y) = 1 +|> where sum(x) over (partition by y) = 1 -- !query analysis org.apache.spark.sql.AnalysisException { @@ -716,7 +716,26 @@ org.apache.spark.sql.AnalysisException -- !query -select * from t where first_value(x) over (partition by y) = 1 +table t +|> where sum(x) over w = 1 + window w as (partition by y) +-- !query analysis +org.apache.spark.sql.catalyst.parser.ParseException +{ + "errorClass" : "NOT_ALLOWED_IN_PIPE_OPERATOR_WHERE.WINDOW_CLAUSE", + "sqlState" : "42601", + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 1, + "stopIndex" : 66, + "fragment" : "table t\n|> where sum(x) over w = 1\n window w as (partition by y)" + } ] +} + + +-- !query +select * from t where sum(x) over (partition by y) = 1 -- !query analysis org.apache.spark.sql.AnalysisException { @@ -2274,49 +2293,6 @@ org.apache.spark.sql.catalyst.parser.ParseException } --- !query -table windowTestData -|> window w as (partition by cte order by val) -|> select cate, sum(val) over w --- !query analysis -org.apache.spark.sql.catalyst.parser.ParseException -{ - "errorClass" : "UNSUPPORTED_FEATURE.CLAUSE_WITH_PIPE_OPERATORS", - "sqlState" : "0A000", - "messageParameters" : { - "clauses" : "the WINDOW clause" - }, - "queryContext" : [ { - "objectType" : "", - "objectName" : "", - "startIndex" : 25, - "stopIndex" : 67, - "fragment" : "window w as (partition by cte order by val)" - } ] -} - - --- !query -table windowTestData -|> window w as (partition by cate order by val) limit 5 --- !query analysis -org.apache.spark.sql.catalyst.parser.ParseException -{ - "errorClass" : "UNSUPPORTED_FEATURE.CLAUSE_WITH_PIPE_OPERATORS", - "sqlState" : "0A000", - "messageParameters" : { - "clauses" : "the WINDOW clause" - }, - "queryContext" : [ { - "objectType" : "", - "objectName" : "", - "startIndex" : 25, - "stopIndex" : 76, - "fragment" : "window w as (partition by cate order by val) limit 5" - } ] -} - - -- !query table other |> aggregate sum(b) as result group by a @@ -2769,6 +2745,234 @@ 
org.apache.spark.sql.catalyst.ExtendedAnalysisException } +-- !query +table windowTestData +|> select cate, sum(val) over w + window w as (partition by cate order by val) +-- !query analysis +Project [cate#x, sum(val) OVER (PARTITION BY cate ORDER BY val ASC NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)#xL] ++- Project [cate#x, val#x, sum(val) OVER (PARTITION BY cate ORDER BY val ASC NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)#xL, sum(val) OVER (PARTITION BY cate ORDER BY val ASC NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)#xL] + +- Window [sum(val#x) windowspecdefinition(cate#x, val#x ASC NULLS FIRST, specifiedwindowframe(RangeFrame, unboundedpreceding$(), currentrow$())) AS sum(val) OVER (PARTITION BY cate ORDER BY val ASC NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)#xL], [cate#x], [val#x ASC NULLS FIRST] + +- Project [cate#x, val#x] + +- SubqueryAlias windowtestdata + +- View (`windowTestData`, [val#x, val_long#xL, val_double#x, val_date#x, val_timestamp#x, cate#x]) + +- Project [cast(val#x as int) AS val#x, cast(val_long#xL as bigint) AS val_long#xL, cast(val_double#x as double) AS val_double#x, cast(val_date#x as date) AS val_date#x, cast(val_timestamp#x as timestamp) AS val_timestamp#x, cast(cate#x as string) AS cate#x] + +- Project [val#x, val_long#xL, val_double#x, val_date#x, val_timestamp#x, cate#x] + +- SubqueryAlias testData + +- LocalRelation [val#x, val_long#xL, val_double#x, val_date#x, val_timestamp#x, cate#x] + + +-- !query +table windowTestData +|> select cate, sum(val) over w + window w as (order by val_timestamp range between unbounded preceding and current row) +-- !query analysis +Project [cate#x, sum(val) OVER (ORDER BY val_timestamp ASC NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)#xL] ++- Project [cate#x, val#x, val_timestamp#x, sum(val) OVER (ORDER BY val_timestamp ASC NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)#xL, sum(val) OVER (ORDER BY val_timestamp ASC NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)#xL] + +- Window [sum(val#x) windowspecdefinition(val_timestamp#x ASC NULLS FIRST, specifiedwindowframe(RangeFrame, unboundedpreceding$(), currentrow$())) AS sum(val) OVER (ORDER BY val_timestamp ASC NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)#xL], [val_timestamp#x ASC NULLS FIRST] + +- Project [cate#x, val#x, val_timestamp#x] + +- SubqueryAlias windowtestdata + +- View (`windowTestData`, [val#x, val_long#xL, val_double#x, val_date#x, val_timestamp#x, cate#x]) + +- Project [cast(val#x as int) AS val#x, cast(val_long#xL as bigint) AS val_long#xL, cast(val_double#x as double) AS val_double#x, cast(val_date#x as date) AS val_date#x, cast(val_timestamp#x as timestamp) AS val_timestamp#x, cast(cate#x as string) AS cate#x] + +- Project [val#x, val_long#xL, val_double#x, val_date#x, val_timestamp#x, cate#x] + +- SubqueryAlias testData + +- LocalRelation [val#x, val_long#xL, val_double#x, val_date#x, val_timestamp#x, cate#x] + + +-- !query +table windowTestData +|> select cate, val + window w as (partition by cate order by val) +-- !query analysis +Project [cate#x, val#x] ++- SubqueryAlias windowtestdata + +- View (`windowTestData`, [val#x, val_long#xL, val_double#x, val_date#x, val_timestamp#x, cate#x]) + +- Project [cast(val#x as int) AS val#x, cast(val_long#xL as bigint) AS val_long#xL, cast(val_double#x as double) AS val_double#x, cast(val_date#x as date) AS val_date#x, cast(val_timestamp#x as timestamp) AS val_timestamp#x, 
cast(cate#x as string) AS cate#x] + +- Project [val#x, val_long#xL, val_double#x, val_date#x, val_timestamp#x, cate#x] + +- SubqueryAlias testData + +- LocalRelation [val#x, val_long#xL, val_double#x, val_date#x, val_timestamp#x, cate#x] + + +-- !query +table windowTestData +|> select cate, val, sum(val) over w as sum_val + window w as (partition by cate) +|> select cate, val, sum_val, first_value(cate) over w + window w as (order by val) +-- !query analysis +Project [cate#x, val#x, sum_val#xL, first_value(cate) OVER (ORDER BY val ASC NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)#x] ++- Project [cate#x, val#x, sum_val#xL, first_value(cate) OVER (ORDER BY val ASC NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)#x, first_value(cate) OVER (ORDER BY val ASC NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)#x] + +- Window [first_value(cate#x, false) windowspecdefinition(val#x ASC NULLS FIRST, specifiedwindowframe(RangeFrame, unboundedpreceding$(), currentrow$())) AS first_value(cate) OVER (ORDER BY val ASC NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)#x], [val#x ASC NULLS FIRST] + +- Project [cate#x, val#x, sum_val#xL] + +- Project [cate#x, val#x, sum_val#xL] + +- Project [cate#x, val#x, _we0#xL, pipeselect(_we0#xL) AS sum_val#xL] + +- Window [sum(val#x) windowspecdefinition(cate#x, specifiedwindowframe(RowFrame, unboundedpreceding$(), unboundedfollowing$())) AS _we0#xL], [cate#x] + +- Project [cate#x, val#x] + +- SubqueryAlias windowtestdata + +- View (`windowTestData`, [val#x, val_long#xL, val_double#x, val_date#x, val_timestamp#x, cate#x]) + +- Project [cast(val#x as int) AS val#x, cast(val_long#xL as bigint) AS val_long#xL, cast(val_double#x as double) AS val_double#x, cast(val_date#x as date) AS val_date#x, cast(val_timestamp#x as timestamp) AS val_timestamp#x, cast(cate#x as string) AS cate#x] + +- Project [val#x, val_long#xL, val_double#x, val_date#x, val_timestamp#x, cate#x] + +- SubqueryAlias testData + +- LocalRelation [val#x, val_long#xL, val_double#x, val_date#x, val_timestamp#x, cate#x] + + +-- !query +table windowTestData +|> select cate, val, sum(val) over w1, first_value(cate) over w2 + window w1 as (partition by cate), w2 as (order by val) +-- !query analysis +Project [cate#x, val#x, sum(val) OVER (PARTITION BY cate ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)#xL, first_value(cate) OVER (ORDER BY val ASC NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)#x] ++- Project [cate#x, val#x, sum(val) OVER (PARTITION BY cate ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)#xL, first_value(cate) OVER (ORDER BY val ASC NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)#x, sum(val) OVER (PARTITION BY cate ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)#xL, first_value(cate) OVER (ORDER BY val ASC NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)#x] + +- Window [first_value(cate#x, false) windowspecdefinition(val#x ASC NULLS FIRST, specifiedwindowframe(RangeFrame, unboundedpreceding$(), currentrow$())) AS first_value(cate) OVER (ORDER BY val ASC NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)#x], [val#x ASC NULLS FIRST] + +- Window [sum(val#x) windowspecdefinition(cate#x, specifiedwindowframe(RowFrame, unboundedpreceding$(), unboundedfollowing$())) AS sum(val) OVER (PARTITION BY cate ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)#xL], [cate#x] + +- Project [cate#x, val#x] + +- SubqueryAlias windowtestdata + +- View (`windowTestData`, 
[val#x, val_long#xL, val_double#x, val_date#x, val_timestamp#x, cate#x]) + +- Project [cast(val#x as int) AS val#x, cast(val_long#xL as bigint) AS val_long#xL, cast(val_double#x as double) AS val_double#x, cast(val_date#x as date) AS val_date#x, cast(val_timestamp#x as timestamp) AS val_timestamp#x, cast(cate#x as string) AS cate#x] + +- Project [val#x, val_long#xL, val_double#x, val_date#x, val_timestamp#x, cate#x] + +- SubqueryAlias testData + +- LocalRelation [val#x, val_long#xL, val_double#x, val_date#x, val_timestamp#x, cate#x] + + +-- !query +table windowTestData +|> select cate, val, sum(val) over w, first_value(val) over w + window w1 as (partition by cate order by val) +-- !query analysis +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "MISSING_WINDOW_SPECIFICATION", + "sqlState" : "42P20", + "messageParameters" : { + "docroot" : "https://spark.apache.org/docs/latest", + "windowName" : "w" + } +} + + +-- !query +(select col from st) +|> select col.i1, sum(col.i2) over w + window w as (partition by col.i1 order by col.i2) +-- !query analysis +Project [i1#x, sum(col.i2) OVER (PARTITION BY col.i1 ORDER BY col.i2 ASC NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)#xL] ++- Project [i1#x, _w0#x, _w1#x, sum(col.i2) OVER (PARTITION BY col.i1 ORDER BY col.i2 ASC NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)#xL, sum(col.i2) OVER (PARTITION BY col.i1 ORDER BY col.i2 ASC NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)#xL] + +- Window [sum(_w0#x) windowspecdefinition(_w1#x, _w0#x ASC NULLS FIRST, specifiedwindowframe(RangeFrame, unboundedpreceding$(), currentrow$())) AS sum(col.i2) OVER (PARTITION BY col.i1 ORDER BY col.i2 ASC NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)#xL], [_w1#x], [_w0#x ASC NULLS FIRST] + +- Project [col#x.i1 AS i1#x, col#x.i2 AS _w0#x, col#x.i1 AS _w1#x] + +- Project [col#x] + +- SubqueryAlias spark_catalog.default.st + +- Relation spark_catalog.default.st[x#x,col#x] parquet + + +-- !query +table st +|> select st.col.i1, sum(st.col.i2) over w + window w as (partition by st.col.i1 order by st.col.i2) +-- !query analysis +Project [i1#x, sum(col.i2) OVER (PARTITION BY col.i1 ORDER BY col.i2 ASC NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)#xL] ++- Project [i1#x, _w0#x, _w1#x, sum(col.i2) OVER (PARTITION BY col.i1 ORDER BY col.i2 ASC NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)#xL, sum(col.i2) OVER (PARTITION BY col.i1 ORDER BY col.i2 ASC NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)#xL] + +- Window [sum(_w0#x) windowspecdefinition(_w1#x, _w0#x ASC NULLS FIRST, specifiedwindowframe(RangeFrame, unboundedpreceding$(), currentrow$())) AS sum(col.i2) OVER (PARTITION BY col.i1 ORDER BY col.i2 ASC NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)#xL], [_w1#x], [_w0#x ASC NULLS FIRST] + +- Project [col#x.i1 AS i1#x, col#x.i2 AS _w0#x, col#x.i1 AS _w1#x] + +- SubqueryAlias spark_catalog.default.st + +- Relation spark_catalog.default.st[x#x,col#x] parquet + + +-- !query +table st +|> select spark_catalog.default.st.col.i1, sum(spark_catalog.default.st.col.i2) over w + window w as (partition by spark_catalog.default.st.col.i1 order by spark_catalog.default.st.col.i2) +-- !query analysis +Project [i1#x, sum(col.i2) OVER (PARTITION BY col.i1 ORDER BY col.i2 ASC NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)#xL] ++- Project [i1#x, _w0#x, _w1#x, sum(col.i2) OVER (PARTITION BY col.i1 ORDER BY col.i2 ASC NULLS FIRST RANGE BETWEEN 
UNBOUNDED PRECEDING AND CURRENT ROW)#xL, sum(col.i2) OVER (PARTITION BY col.i1 ORDER BY col.i2 ASC NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)#xL] + +- Window [sum(_w0#x) windowspecdefinition(_w1#x, _w0#x ASC NULLS FIRST, specifiedwindowframe(RangeFrame, unboundedpreceding$(), currentrow$())) AS sum(col.i2) OVER (PARTITION BY col.i1 ORDER BY col.i2 ASC NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)#xL], [_w1#x], [_w0#x ASC NULLS FIRST] + +- Project [col#x.i1 AS i1#x, col#x.i2 AS _w0#x, col#x.i1 AS _w1#x] + +- SubqueryAlias spark_catalog.default.st + +- Relation spark_catalog.default.st[x#x,col#x] parquet + + +-- !query +table windowTestData +|> select cate, sum(val) over val + window val as (partition by cate order by val) +-- !query analysis +Project [cate#x, sum(val) OVER (PARTITION BY cate ORDER BY val ASC NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)#xL] ++- Project [cate#x, val#x, sum(val) OVER (PARTITION BY cate ORDER BY val ASC NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)#xL, sum(val) OVER (PARTITION BY cate ORDER BY val ASC NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)#xL] + +- Window [sum(val#x) windowspecdefinition(cate#x, val#x ASC NULLS FIRST, specifiedwindowframe(RangeFrame, unboundedpreceding$(), currentrow$())) AS sum(val) OVER (PARTITION BY cate ORDER BY val ASC NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)#xL], [cate#x], [val#x ASC NULLS FIRST] + +- Project [cate#x, val#x] + +- SubqueryAlias windowtestdata + +- View (`windowTestData`, [val#x, val_long#xL, val_double#x, val_date#x, val_timestamp#x, cate#x]) + +- Project [cast(val#x as int) AS val#x, cast(val_long#xL as bigint) AS val_long#xL, cast(val_double#x as double) AS val_double#x, cast(val_date#x as date) AS val_date#x, cast(val_timestamp#x as timestamp) AS val_timestamp#x, cast(cate#x as string) AS cate#x] + +- Project [val#x, val_long#xL, val_double#x, val_date#x, val_timestamp#x, cate#x] + +- SubqueryAlias testData + +- LocalRelation [val#x, val_long#xL, val_double#x, val_date#x, val_timestamp#x, cate#x] + + +-- !query +table windowTestData +|> select cate, sum(val) over w +-- !query analysis +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "MISSING_WINDOW_SPECIFICATION", + "sqlState" : "42P20", + "messageParameters" : { + "docroot" : "https://spark.apache.org/docs/latest", + "windowName" : "w" + } +} + + +-- !query +table windowTestData +|> select cate, val, sum(val) over w1, first_value(cate) over w2 + window w1 as (partition by cate) + window w2 as (order by val) +-- !query analysis +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "MISSING_WINDOW_SPECIFICATION", + "sqlState" : "42P20", + "messageParameters" : { + "docroot" : "https://spark.apache.org/docs/latest", + "windowName" : "w2" + } +} + + +-- !query +table windowTestData +|> select cate, val, sum(val) over w as sum_val + window w as (partition by cate order by val) +|> select cate, val, sum_val, first_value(cate) over w +-- !query analysis +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "MISSING_WINDOW_SPECIFICATION", + "sqlState" : "42P20", + "messageParameters" : { + "docroot" : "https://spark.apache.org/docs/latest", + "windowName" : "w" + } +} + + +-- !query +table windowTestData +|> select cate, val, first_value(cate) over w as first_val +|> select cate, val, sum(val) over w as sum_val + window w as (order by val) +-- !query analysis 
+org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "MISSING_WINDOW_SPECIFICATION", + "sqlState" : "42P20", + "messageParameters" : { + "docroot" : "https://spark.apache.org/docs/latest", + "windowName" : "w" + } +} + + -- !query drop table t -- !query analysis diff --git a/sql/core/src/test/resources/sql-tests/inputs/pipe-operators.sql b/sql/core/src/test/resources/sql-tests/inputs/pipe-operators.sql index 8de22e65f0fb5..8bca7144c0a98 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/pipe-operators.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/pipe-operators.sql @@ -228,10 +228,17 @@ table t |> where y = 'abc' or length(y) + sum(x) = 1; -- Window functions are not allowed in the WHERE clause (pipe operators or otherwise). +-- (Note: to implement this behavior, perform the window function first separately in a SELECT +-- clause and then add a pipe-operator WHERE clause referring to the result of the window function +-- expression(s) therein). table t -|> where first_value(x) over (partition by y) = 1; +|> where sum(x) over (partition by y) = 1; -select * from t where first_value(x) over (partition by y) = 1; +table t +|> where sum(x) over w = 1 + window w as (partition by y); + +select * from t where sum(x) over (partition by y) = 1; -- Pipe operators may only refer to attributes produced as output from the directly-preceding -- pipe operator, not from earlier ones. @@ -665,15 +672,6 @@ table t table t |> order by x sort by x; --- The WINDOW clause is not supported yet. -table windowTestData -|> window w as (partition by cte order by val) -|> select cate, sum(val) over w; - --- WINDOW and LIMIT are not supported at the same time. -table windowTestData -|> window w as (partition by cate order by val) limit 5; - -- Aggregation operators: positive tests. ----------------------------------------- @@ -821,6 +819,83 @@ select 1 x, 2 y, 3 z table other |> aggregate b group by a; +-- WINDOW operators (within SELECT): positive tests. +--------------------------------------------------- + +-- SELECT with a WINDOW clause. +table windowTestData +|> select cate, sum(val) over w + window w as (partition by cate order by val); + +-- SELECT with RANGE BETWEEN as part of the window definition. +table windowTestData +|> select cate, sum(val) over w + window w as (order by val_timestamp range between unbounded preceding and current row); + +-- SELECT with a WINDOW clause not being referred in the SELECT list. +table windowTestData +|> select cate, val + window w as (partition by cate order by val); + +-- multiple SELECT clauses, each with a WINDOW clause (with the same window definition names). +table windowTestData +|> select cate, val, sum(val) over w as sum_val + window w as (partition by cate) +|> select cate, val, sum_val, first_value(cate) over w + window w as (order by val); + +-- SELECT with a WINDOW clause for multiple window definitions. +table windowTestData +|> select cate, val, sum(val) over w1, first_value(cate) over w2 + window w1 as (partition by cate), w2 as (order by val); + +-- SELECT with a WINDOW clause for multiple window functions over one window definition +table windowTestData +|> select cate, val, sum(val) over w, first_value(val) over w + window w1 as (partition by cate order by val); + +-- SELECT with a WINDOW clause, using struct fields. 
+(select col from st) +|> select col.i1, sum(col.i2) over w + window w as (partition by col.i1 order by col.i2); + +table st +|> select st.col.i1, sum(st.col.i2) over w + window w as (partition by st.col.i1 order by st.col.i2); + +table st +|> select spark_catalog.default.st.col.i1, sum(spark_catalog.default.st.col.i2) over w + window w as (partition by spark_catalog.default.st.col.i1 order by spark_catalog.default.st.col.i2); + +-- SELECT with one WINDOW definition shadowing a column name. +table windowTestData +|> select cate, sum(val) over val + window val as (partition by cate order by val); + +-- WINDOW operators (within SELECT): negative tests. +--------------------------------------------------- + +-- WINDOW without definition is not allowed in the pipe operator SELECT clause. +table windowTestData +|> select cate, sum(val) over w; + +-- Multiple WINDOW clauses are not supported in the pipe operator SELECT clause. +table windowTestData +|> select cate, val, sum(val) over w1, first_value(cate) over w2 + window w1 as (partition by cate) + window w2 as (order by val); + +-- WINDOW definition cannot be referred across different pipe operator SELECT clauses. +table windowTestData +|> select cate, val, sum(val) over w as sum_val + window w as (partition by cate order by val) +|> select cate, val, sum_val, first_value(cate) over w; + +table windowTestData +|> select cate, val, first_value(cate) over w as first_val +|> select cate, val, sum(val) over w as sum_val + window w as (order by val); + -- Cleanup. ----------- drop table t; diff --git a/sql/core/src/test/resources/sql-tests/results/pipe-operators.sql.out b/sql/core/src/test/resources/sql-tests/results/pipe-operators.sql.out index 8ad2def84082e..aae68dddbaab3 100644 --- a/sql/core/src/test/resources/sql-tests/results/pipe-operators.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/pipe-operators.sql.out @@ -637,7 +637,7 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException -- !query table t -|> where first_value(x) over (partition by y) = 1 +|> where sum(x) over (partition by y) = 1 -- !query schema struct<> -- !query output @@ -651,7 +651,28 @@ org.apache.spark.sql.AnalysisException -- !query -select * from t where first_value(x) over (partition by y) = 1 +table t +|> where sum(x) over w = 1 + window w as (partition by y) +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.parser.ParseException +{ + "errorClass" : "NOT_ALLOWED_IN_PIPE_OPERATOR_WHERE.WINDOW_CLAUSE", + "sqlState" : "42601", + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 1, + "stopIndex" : 66, + "fragment" : "table t\n|> where sum(x) over w = 1\n window w as (partition by y)" + } ] +} + + +-- !query +select * from t where sum(x) over (partition by y) = 1 -- !query schema struct<> -- !query output @@ -1915,53 +1936,6 @@ org.apache.spark.sql.catalyst.parser.ParseException } --- !query -table windowTestData -|> window w as (partition by cte order by val) -|> select cate, sum(val) over w --- !query schema -struct<> --- !query output -org.apache.spark.sql.catalyst.parser.ParseException -{ - "errorClass" : "UNSUPPORTED_FEATURE.CLAUSE_WITH_PIPE_OPERATORS", - "sqlState" : "0A000", - "messageParameters" : { - "clauses" : "the WINDOW clause" - }, - "queryContext" : [ { - "objectType" : "", - "objectName" : "", - "startIndex" : 25, - "stopIndex" : 67, - "fragment" : "window w as (partition by cte order by val)" - } ] -} - - --- !query -table windowTestData -|> window w as (partition by cate order by val) 
limit 5 --- !query schema -struct<> --- !query output -org.apache.spark.sql.catalyst.parser.ParseException -{ - "errorClass" : "UNSUPPORTED_FEATURE.CLAUSE_WITH_PIPE_OPERATORS", - "sqlState" : "0A000", - "messageParameters" : { - "clauses" : "the WINDOW clause" - }, - "queryContext" : [ { - "objectType" : "", - "objectName" : "", - "startIndex" : 25, - "stopIndex" : 76, - "fragment" : "window w as (partition by cate order by val) limit 5" - } ] -} - - -- !query table other |> aggregate sum(b) as result group by a @@ -2442,6 +2416,238 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException } +-- !query +table windowTestData +|> select cate, sum(val) over w + window w as (partition by cate order by val) +-- !query schema +struct +-- !query output +NULL 3 +NULL NULL +a 2 +a 2 +a 4 +a NULL +b 1 +b 3 +b 6 + + +-- !query +table windowTestData +|> select cate, sum(val) over w + window w as (order by val_timestamp range between unbounded preceding and current row) +-- !query schema +struct +-- !query output +NULL 5 +NULL NULL +a 13 +a 5 +a 5 +a 6 +b 13 +b 5 +b 8 + + +-- !query +table windowTestData +|> select cate, val + window w as (partition by cate order by val) +-- !query schema +struct +-- !query output +NULL 3 +NULL NULL +a 1 +a 1 +a 2 +a NULL +b 1 +b 2 +b 3 + + +-- !query +table windowTestData +|> select cate, val, sum(val) over w as sum_val + window w as (partition by cate) +|> select cate, val, sum_val, first_value(cate) over w + window w as (order by val) +-- !query schema +struct +-- !query output +NULL 3 3 a +NULL NULL 3 a +a 1 4 a +a 1 4 a +a 2 4 a +a NULL 4 a +b 1 6 a +b 2 6 a +b 3 6 a + + +-- !query +table windowTestData +|> select cate, val, sum(val) over w1, first_value(cate) over w2 + window w1 as (partition by cate), w2 as (order by val) +-- !query schema +struct +-- !query output +NULL 3 3 a +NULL NULL 3 a +a 1 4 a +a 1 4 a +a 2 4 a +a NULL 4 a +b 1 6 a +b 2 6 a +b 3 6 a + + +-- !query +table windowTestData +|> select cate, val, sum(val) over w, first_value(val) over w + window w1 as (partition by cate order by val) +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "MISSING_WINDOW_SPECIFICATION", + "sqlState" : "42P20", + "messageParameters" : { + "docroot" : "https://spark.apache.org/docs/latest", + "windowName" : "w" + } +} + + +-- !query +(select col from st) +|> select col.i1, sum(col.i2) over w + window w as (partition by col.i1 order by col.i2) +-- !query schema +struct +-- !query output +2 3 + + +-- !query +table st +|> select st.col.i1, sum(st.col.i2) over w + window w as (partition by st.col.i1 order by st.col.i2) +-- !query schema +struct +-- !query output +2 3 + + +-- !query +table st +|> select spark_catalog.default.st.col.i1, sum(spark_catalog.default.st.col.i2) over w + window w as (partition by spark_catalog.default.st.col.i1 order by spark_catalog.default.st.col.i2) +-- !query schema +struct +-- !query output +2 3 + + +-- !query +table windowTestData +|> select cate, sum(val) over val + window val as (partition by cate order by val) +-- !query schema +struct +-- !query output +NULL 3 +NULL NULL +a 2 +a 2 +a 4 +a NULL +b 1 +b 3 +b 6 + + +-- !query +table windowTestData +|> select cate, sum(val) over w +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "MISSING_WINDOW_SPECIFICATION", + "sqlState" : "42P20", + "messageParameters" : { + "docroot" : "https://spark.apache.org/docs/latest", + "windowName" : "w" + } +} + + +-- !query +table 
windowTestData +|> select cate, val, sum(val) over w1, first_value(cate) over w2 + window w1 as (partition by cate) + window w2 as (order by val) +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "MISSING_WINDOW_SPECIFICATION", + "sqlState" : "42P20", + "messageParameters" : { + "docroot" : "https://spark.apache.org/docs/latest", + "windowName" : "w2" + } +} + + +-- !query +table windowTestData +|> select cate, val, sum(val) over w as sum_val + window w as (partition by cate order by val) +|> select cate, val, sum_val, first_value(cate) over w +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "MISSING_WINDOW_SPECIFICATION", + "sqlState" : "42P20", + "messageParameters" : { + "docroot" : "https://spark.apache.org/docs/latest", + "windowName" : "w" + } +} + + +-- !query +table windowTestData +|> select cate, val, first_value(cate) over w as first_val +|> select cate, val, sum(val) over w as sum_val + window w as (order by val) +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "MISSING_WINDOW_SPECIFICATION", + "sqlState" : "42P20", + "messageParameters" : { + "docroot" : "https://spark.apache.org/docs/latest", + "windowName" : "w" + } +} + + -- !query drop table t -- !query schema diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala index 577994a1e0cb2..357fd8beb961a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala @@ -673,7 +673,9 @@ class SparkSqlParserSuite extends AnalysisTest with SharedSparkSession { UnresolvedFunction("max", Seq(UnresolvedAttribute("c")), isDistinct = false), WindowSpecReference("w")), None) ), - UnresolvedRelation(TableIdentifier("testData")))), + UnresolvedRelation(TableIdentifier("testData"))), + forPipeSQL = false + ), ioSchema)) assertEqual( @@ -972,6 +974,14 @@ class SparkSqlParserSuite extends AnalysisTest with SharedSparkSession { checkAggregate("SELECT a, b FROM t |> AGGREGATE SUM(a) AS result GROUP BY b") checkAggregate("SELECT a, b FROM t |> AGGREGATE GROUP BY b") checkAggregate("SELECT a, b FROM t |> AGGREGATE COUNT(*) AS result GROUP BY b") + // Window + def checkWindow(query: String): Unit = check(query, Seq(WITH_WINDOW_DEFINITION)) + checkWindow( + """ + |TABLE windowTestData + ||> SELECT cate, SUM(val) OVER w + | WINDOW w AS (PARTITION BY cate ORDER BY val) + |""".stripMargin) } } } From 40ffdefc2e11f9605e9bc84f3f8f3b57cb57f0d4 Mon Sep 17 00:00:00 2001 From: Haejoon Lee Date: Wed, 13 Nov 2024 09:15:13 +0100 Subject: [PATCH 20/39] [SPARK-50250][SQL] Assign appropriate error condition for `_LEGACY_ERROR_TEMP_2075`: `UNSUPPORTED_FEATURE.WRITE_FOR_BINARY_SOURCE` ### What changes were proposed in this pull request? This PR proposes to Integrate `_LEGACY_ERROR_TEMP_2075 ` into `UNSUPPORTED_FEATURE.WRITE_FOR_BINARY_SOURCE ` ### Why are the changes needed? To improve the error message by assigning proper error condition and SQLSTATE ### Does this PR introduce _any_ user-facing change? No, only user-facing error message improved ### How was this patch tested? Updated the existing tests ### Was this patch authored or co-authored using generative AI tooling? No Closes #48780 from itholic/LEGACY_2075. 
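
For context, a minimal sketch of how the new error condition surfaces to users; the session setup and paths below are illustrative assumptions, not taken from this patch:

```scala
import org.apache.spark.sql.SparkSession

// Assumed local session and scratch paths, for illustration only.
val spark = SparkSession.builder().master("local[*]").appName("binary-write-demo").getOrCreate()

// Reading with the binary file data source is supported...
val df = spark.read.format("binaryFile").load("/tmp/binary_input")

// ...but writing is not, so this is expected to fail with SparkUnsupportedOperationException
// carrying the new condition UNSUPPORTED_FEATURE.WRITE_FOR_BINARY_SOURCE
// (previously the opaque _LEGACY_ERROR_TEMP_2075).
df.write.format("binaryFile").save("/tmp/binary_output")
```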
Lead-authored-by: Haejoon Lee Co-authored-by: Haejoon Lee Signed-off-by: Max Gekk --- .../src/main/resources/error/error-conditions.json | 10 +++++----- .../apache/spark/sql/errors/QueryExecutionErrors.scala | 2 +- .../datasources/binaryfile/BinaryFileFormatSuite.scala | 2 +- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json index cc31678bc1ec9..b3c92a9f2b9d1 100644 --- a/common/utils/src/main/resources/error/error-conditions.json +++ b/common/utils/src/main/resources/error/error-conditions.json @@ -5382,6 +5382,11 @@ "message" : [ "Update column nullability for MySQL and MS SQL Server." ] + }, + "WRITE_FOR_BINARY_SOURCE" : { + "message" : [ + "Write for the binary file data source." + ] } }, "sqlState" : "0A000" @@ -7083,11 +7088,6 @@ "user-specified schema." ] }, - "_LEGACY_ERROR_TEMP_2075" : { - "message" : [ - "Write is not supported for binary file data source." - ] - }, "_LEGACY_ERROR_TEMP_2076" : { "message" : [ "The length of is , which exceeds the max length allowed: ." diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala index 6cf930f18dc2d..0aa21a4d79c78 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala @@ -914,7 +914,7 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE } def writeUnsupportedForBinaryFileDataSourceError(): SparkUnsupportedOperationException = { - new SparkUnsupportedOperationException("_LEGACY_ERROR_TEMP_2075") + new SparkUnsupportedOperationException("UNSUPPORTED_FEATURE.WRITE_FOR_BINARY_SOURCE") } def fileLengthExceedsMaxLengthError(status: FileStatus, maxLength: Int): Throwable = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/binaryfile/BinaryFileFormatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/binaryfile/BinaryFileFormatSuite.scala index 387a2baa256bf..62f2f2cb10a85 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/binaryfile/BinaryFileFormatSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/binaryfile/BinaryFileFormatSuite.scala @@ -168,7 +168,7 @@ class BinaryFileFormatSuite extends QueryTest with SharedSparkSession { .format(BINARY_FILE) .save(s"$tmpDir/test_save") }, - condition = "_LEGACY_ERROR_TEMP_2075", + condition = "UNSUPPORTED_FEATURE.WRITE_FOR_BINARY_SOURCE", parameters = Map.empty) } } From ede05fa500feb02be23f89f37e8b29265ddfc5cc Mon Sep 17 00:00:00 2001 From: Haejoon Lee Date: Wed, 13 Nov 2024 09:18:06 +0100 Subject: [PATCH 21/39] [SPARK-50248][SQL] Assign appropriate error condition for `_LEGACY_ERROR_TEMP_2058`: `INVALID_PARTITION_VALUE` ### What changes were proposed in this pull request? This PR proposes to Integrate `_LEGACY_ERROR_TEMP_2058 ` into `INVALID_PARTITION_VALUE ` ### Why are the changes needed? To improve the error message by assigning proper error condition and SQLSTATE ### Does this PR introduce _any_ user-facing change? No, only user-facing error message improved ### How was this patch tested? Updated the existing tests ### Was this patch authored or co-authored using generative AI tooling? No Closes #48778 from itholic/LEGACY_2058. 
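
For context, a minimal sketch (not part of the patch) of a layout that triggers the renamed error condition; the paths, schema, and exact trigger point are illustrative assumptions:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

// Assumed local session and scratch path, for illustration only.
val spark = SparkSession.builder().master("local[*]").appName("partition-value-demo").getOrCreate()

// Create a partition directory whose value cannot be cast to the declared INT type.
spark.range(1).selectExpr("CAST(id AS INT) AS id").write.parquet("/tmp/part_tbl/a=foo")

// With spark.sql.sources.validatePartitionColumns=true (the default), partition discovery is
// expected to fail with INVALID_PARTITION_VALUE (value 'foo', type "INT", column `a`)
// instead of the former _LEGACY_ERROR_TEMP_2058.
spark.read
  .schema(StructType(Seq(StructField("id", IntegerType), StructField("a", IntegerType))))
  .parquet("/tmp/part_tbl")
  .collect()
```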
Authored-by: Haejoon Lee Signed-off-by: Max Gekk --- .../src/main/resources/error/error-conditions.json | 11 ++++++----- .../spark/sql/errors/QueryExecutionErrors.scala | 8 ++++---- .../sql/execution/datasources/FileIndexSuite.scala | 4 ++-- 3 files changed, 12 insertions(+), 11 deletions(-) diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json index b3c92a9f2b9d1..553d085f88627 100644 --- a/common/utils/src/main/resources/error/error-conditions.json +++ b/common/utils/src/main/resources/error/error-conditions.json @@ -2950,6 +2950,12 @@ }, "sqlState" : "42601" }, + "INVALID_PARTITION_VALUE" : { + "message" : [ + "Failed to cast value to data type for partition column . Ensure the value matches the expected data type for this partition column." + ], + "sqlState" : "42846" + }, "INVALID_PROPERTY_KEY" : { "message" : [ " is an invalid property key, please use quotes, e.g. SET =." @@ -7026,11 +7032,6 @@ "Unable to clear partition directory prior to writing to it." ] }, - "_LEGACY_ERROR_TEMP_2058" : { - "message" : [ - "Failed to cast value `` to `` for partition column ``." - ] - }, "_LEGACY_ERROR_TEMP_2059" : { "message" : [ "End of stream." diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala index 0aa21a4d79c78..0e3f37d8d6fb5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala @@ -795,11 +795,11 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE def failedToCastValueToDataTypeForPartitionColumnError( value: String, dataType: DataType, columnName: String): SparkRuntimeException = { new SparkRuntimeException( - errorClass = "_LEGACY_ERROR_TEMP_2058", + errorClass = "INVALID_PARTITION_VALUE", messageParameters = Map( - "value" -> value, - "dataType" -> dataType.toString(), - "columnName" -> columnName)) + "value" -> toSQLValue(value), + "dataType" -> toSQLType(dataType), + "columnName" -> toSQLId(columnName))) } def endOfStreamError(): Throwable = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileIndexSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileIndexSuite.scala index e9f78f9f598e1..33b4cc1d2e7f2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileIndexSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileIndexSuite.scala @@ -137,8 +137,8 @@ class FileIndexSuite extends SharedSparkSession { exception = intercept[SparkRuntimeException] { fileIndex.partitionSpec() }, - condition = "_LEGACY_ERROR_TEMP_2058", - parameters = Map("value" -> "foo", "dataType" -> "IntegerType", "columnName" -> "a") + condition = "INVALID_PARTITION_VALUE", + parameters = Map("value" -> "'foo'", "dataType" -> "\"INT\"", "columnName" -> "`a`") ) } From 6fb1d438191262d2a127bc72cbbb1127fcac7587 Mon Sep 17 00:00:00 2001 From: Haejoon Lee Date: Wed, 13 Nov 2024 09:20:39 +0100 Subject: [PATCH 22/39] [SPARK-50246][SQL] Assign appropriate error condition for `_LEGACY_ERROR_TEMP_2167`: `INVALID_JSON_RECORD_TYPE` ### What changes were proposed in this pull request? This PR proposes to Integrate `_LEGACY_ERROR_TEMP_2167 ` into `INVALID_JSON_RECORD_TYPE ` ### Why are the changes needed? 
To improve the error message by assigning proper error condition and SQLSTATE ### Does this PR introduce _any_ user-facing change? No, only user-facing error message improved ### How was this patch tested? Updated the existing tests ### Was this patch authored or co-authored using generative AI tooling? No Closes #48775 from itholic/LEGACY_2167. Authored-by: Haejoon Lee Signed-off-by: Max Gekk --- .../src/main/resources/error/error-conditions.json | 11 ++++++----- .../spark/sql/errors/QueryExecutionErrors.scala | 4 ++-- .../sql/execution/datasources/json/JsonSuite.scala | 4 ++-- 3 files changed, 10 insertions(+), 9 deletions(-) diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json index 553d085f88627..e51b35c0accc2 100644 --- a/common/utils/src/main/resources/error/error-conditions.json +++ b/common/utils/src/main/resources/error/error-conditions.json @@ -2645,6 +2645,12 @@ ], "sqlState" : "2203G" }, + "INVALID_JSON_RECORD_TYPE" : { + "message" : [ + "Detected an invalid type of a JSON record while inferring a common schema in the mode . Expected a STRUCT type, but found ." + ], + "sqlState" : "22023" + }, "INVALID_JSON_ROOT_FIELD" : { "message" : [ "Cannot convert JSON root field to target Spark type." @@ -7354,11 +7360,6 @@ "Malformed JSON." ] }, - "_LEGACY_ERROR_TEMP_2167" : { - "message" : [ - "Malformed records are detected in schema inference. Parse Mode: . Reasons: Failed to infer a common schema. Struct types are expected, but `` was found." - ] - }, "_LEGACY_ERROR_TEMP_2168" : { "message" : [ "Decorrelate inner query through is not supported." diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala index 0e3f37d8d6fb5..09836995925ea 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala @@ -1437,10 +1437,10 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE def malformedRecordsDetectedInSchemaInferenceError(dataType: DataType): Throwable = { new SparkException( - errorClass = "_LEGACY_ERROR_TEMP_2167", + errorClass = "INVALID_JSON_RECORD_TYPE", messageParameters = Map( "failFastMode" -> FailFastMode.name, - "dataType" -> dataType.catalogString), + "invalidType" -> toSQLType(dataType)), cause = null) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index 06183596a54ad..dfbc8e5279aaf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -2078,8 +2078,8 @@ abstract class JsonSuite .option("mode", "FAILFAST") .json(path) }, - condition = "_LEGACY_ERROR_TEMP_2167", - parameters = Map("failFastMode" -> "FAILFAST", "dataType" -> "string|bigint")) + condition = "INVALID_JSON_RECORD_TYPE", + parameters = Map("failFastMode" -> "FAILFAST", "invalidType" -> "\"STRING\"|\"BIGINT\"")) val ex = intercept[SparkException] { spark.read From 898bff21c9921ba40c10ed19034baade5e0ac543 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vladan=20Vasi=C4=87?= Date: Wed, 13 Nov 2024 09:30:30 +0100 Subject: [PATCH 23/39] [SPARK-50245][SQL][TESTS] Extended CollationSuite 
and added tests where SortMergeJoin is forced MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### What changes were proposed in this pull request? I propose extending existing tests in `CollationSuite` and add cases where `SortMergeJoin` is forced and tested for correctness and use of `CollationKey`. ### Why are the changes needed? These changes are needed to properly test behavior of join with collated data when different configs are enabled. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? The change is a test itself. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48774 from vladanvasi-db/vladanvasi-db/collation-suite-test-extension. Authored-by: Vladan Vasić Signed-off-by: Max Gekk --- .../org/apache/spark/sql/CollationSuite.scala | 362 +++++++++--------- 1 file changed, 187 insertions(+), 175 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala index 9a47491b0cca4..9716d342bb6bc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql.connector.catalog.{Identifier, InMemoryTable} import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.CatalogHelper import org.apache.spark.sql.connector.catalog.CatalogV2Util.withDefaultOwnership import org.apache.spark.sql.errors.DataTypeErrors.toSQLType +import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec} import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec @@ -43,6 +44,29 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { private val collationNonPreservingSources = Seq("orc", "csv", "json", "text") private val allFileBasedDataSources = collationPreservingSources ++ collationNonPreservingSources + @inline + private def isSortMergeForced: Boolean = { + SQLConf.get.getConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD) == -1 + } + + private def checkRightTypeOfJoinUsed(queryPlan: SparkPlan): Unit = { + assert( + collectFirst(queryPlan) { + case _: SortMergeJoinExec => assert(isSortMergeForced) + case _: HashJoin => assert(!isSortMergeForced) + }.nonEmpty + ) + } + + private def checkCollationKeyInQueryPlan(queryPlan: SparkPlan, collationName: String): Unit = { + // Only if collation doesn't support binary equality, collation key should be injected. 
+ if (!CollationFactory.fetchCollation(collationName).supportsBinaryEquality) { + assert(queryPlan.toString().contains("collationkey")) + } else { + assert(!queryPlan.toString().contains("collationkey")) + } + } + test("collate returns proper type") { Seq( "utf8_binary", @@ -1419,7 +1443,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { for (codeGen <- Seq("NO_CODEGEN", "CODEGEN_ONLY")) { val collationSetup = if (collation.isEmpty) "" else " COLLATE " + collation val supportsBinaryEquality = collation.isEmpty || collation == "UNICODE" || - CollationFactory.fetchCollation(collation).isUtf8BinaryType + CollationFactory.fetchCollation(collation).supportsBinaryEquality test(s"Group by on map containing$collationSetup strings ($codeGen)") { val tableName = "t" @@ -1589,7 +1613,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { } } - test("hash join should be used for collated strings") { + test("hash join should be used for collated strings if sort merge join is not forced") { val t1 = "T_1" val t2 = "T_2" @@ -1602,47 +1626,48 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { HashJoinTestCase("UNICODE_CI_RTRIM", "aa", "AA ", Seq(Row("aa", 1, "AA ", 2), Row("aa", 1, "aa", 2))) ) - - testCases.foreach(t => { + for { + t <- testCases + broadcastJoinThreshold <- Seq(-1, SQLConf.get.getConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD)) + } { withTable(t1, t2) { - sql(s"CREATE TABLE $t1 (x STRING COLLATE ${t.collation}, i int) USING PARQUET") - sql(s"INSERT INTO $t1 VALUES ('${t.data1}', 1)") + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> broadcastJoinThreshold.toString) { + sql(s"CREATE TABLE $t1 (x STRING COLLATE ${t.collation}, i int) USING PARQUET") + sql(s"INSERT INTO $t1 VALUES ('${t.data1}', 1)") - sql(s"CREATE TABLE $t2 (y STRING COLLATE ${t.collation}, j int) USING PARQUET") - sql(s"INSERT INTO $t2 VALUES ('${t.data2}', 2), ('${t.data1}', 2)") + sql(s"CREATE TABLE $t2 (y STRING COLLATE ${t.collation}, j int) USING PARQUET") + sql(s"INSERT INTO $t2 VALUES ('${t.data2}', 2), ('${t.data1}', 2)") - val df = sql(s"SELECT * FROM $t1 JOIN $t2 ON $t1.x = $t2.y") - checkAnswer(df, t.result) + val df = sql(s"SELECT * FROM $t1 JOIN $t2 ON $t1.x = $t2.y") + checkAnswer(df, t.result) - val queryPlan = df.queryExecution.executedPlan + val queryPlan = df.queryExecution.executedPlan - // confirm that hash join is used instead of sort merge join - assert( - collectFirst(queryPlan) { - case _: HashJoin => () - }.nonEmpty - ) - assert( - collectFirst(queryPlan) { - case _: SortMergeJoinExec => () - }.isEmpty - ) + // confirm that right kind of join is used. + checkRightTypeOfJoinUsed(queryPlan) - // Only if collation doesn't support binary equality, collation key should be injected. - if (!CollationFactory.fetchCollation(t.collation).isUtf8BinaryType) { - assert(collectFirst(queryPlan) { - case b: HashJoin => b.leftKeys.head - }.head.isInstanceOf[CollationKey]) - } else { - assert(!collectFirst(queryPlan) { - case b: HashJoin => b.leftKeys.head - }.head.isInstanceOf[CollationKey]) + if (isSortMergeForced) { + // Confirm proper injection of collation key. + checkCollationKeyInQueryPlan(queryPlan, t.collation) + } + else { + // Only if collation doesn't support binary equality, collation key should be injected. 
+ if (!CollationFactory.fetchCollation(t.collation).supportsBinaryEquality) { + assert(collectFirst(queryPlan) { + case b: HashJoin => b.leftKeys.head + }.head.isInstanceOf[CollationKey]) + } else { + assert(!collectFirst(queryPlan) { + case b: HashJoin => b.leftKeys.head + }.head.isInstanceOf[CollationKey]) + } + } } } - }) + } } - test("hash join should be used for arrays of collated strings") { + test("hash join should be used for arrays of collated strings if sort merge join is not forced") { val t1 = "T_1" val t2 = "T_2" @@ -1660,47 +1685,50 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { Seq(Row(Seq("aa"), 1, Seq("AA "), 2), Row(Seq("aa"), 1, Seq("aa"), 2))) ) - testCases.foreach(t => { + for { + t <- testCases + broadcastJoinThreshold <- Seq(-1, SQLConf.get.getConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD)) + } { withTable(t1, t2) { - sql(s"CREATE TABLE $t1 (x ARRAY, i int) USING PARQUET") - sql(s"INSERT INTO $t1 VALUES (array('${t.data1}'), 1)") + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> broadcastJoinThreshold.toString) { + sql(s"CREATE TABLE $t1 (x ARRAY, i int) USING PARQUET") + sql(s"INSERT INTO $t1 VALUES (array('${t.data1}'), 1)") - sql(s"CREATE TABLE $t2 (y ARRAY, j int) USING PARQUET") - sql(s"INSERT INTO $t2 VALUES (array('${t.data2}'), 2), (array('${t.data1}'), 2)") + sql(s"CREATE TABLE $t2 (y ARRAY, j int) USING PARQUET") + sql(s"INSERT INTO $t2 VALUES (array('${t.data2}'), 2), (array('${t.data1}'), 2)") - val df = sql(s"SELECT * FROM $t1 JOIN $t2 ON $t1.x = $t2.y") - checkAnswer(df, t.result) + val df = sql(s"SELECT * FROM $t1 JOIN $t2 ON $t1.x = $t2.y") + checkAnswer(df, t.result) - val queryPlan = df.queryExecution.executedPlan + val queryPlan = df.queryExecution.executedPlan - // confirm that hash join is used instead of sort merge join - assert( - collectFirst(queryPlan) { - case _: HashJoin => () - }.nonEmpty - ) - assert( - collectFirst(queryPlan) { - case _: ShuffledJoin => () - }.isEmpty - ) + // confirm that right kind of join is used. + checkRightTypeOfJoinUsed(queryPlan) - // Only if collation doesn't support binary equality, collation key should be injected. - if (!CollationFactory.fetchCollation(t.collation).isUtf8BinaryType) { - assert(collectFirst(queryPlan) { - case b: BroadcastHashJoinExec => b.leftKeys.head - }.head.asInstanceOf[ArrayTransform].function.asInstanceOf[LambdaFunction]. - function.isInstanceOf[CollationKey]) - } else { - assert(!collectFirst(queryPlan) { - case b: BroadcastHashJoinExec => b.leftKeys.head - }.head.isInstanceOf[ArrayTransform]) + if (isSortMergeForced) { + // Confirm proper injection of collation key. + checkCollationKeyInQueryPlan(queryPlan, t.collation) + } + else { + // Only if collation doesn't support binary equality, collation key should be injected. + if (!CollationFactory.fetchCollation(t.collation).supportsBinaryEquality) { + assert(collectFirst(queryPlan) { + case b: BroadcastHashJoinExec => b.leftKeys.head + }.head.asInstanceOf[ArrayTransform].function.asInstanceOf[LambdaFunction]. 
+ function.isInstanceOf[CollationKey]) + } else { + assert(!collectFirst(queryPlan) { + case b: BroadcastHashJoinExec => b.leftKeys.head + }.head.isInstanceOf[ArrayTransform]) + } + } } } - }) + } } - test("hash join should be used for arrays of arrays of collated strings") { + test("hash join should be used for arrays of arrays of collated strings " + + "if sort merge join is not forced") { val t1 = "T_1" val t2 = "T_2" @@ -1718,51 +1746,53 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { Seq(Row(Seq(Seq("aa")), 1, Seq(Seq("AA ")), 2), Row(Seq(Seq("aa")), 1, Seq(Seq("aa")), 2))) ) - testCases.foreach(t => { + for { + t <- testCases + broadcastJoinThreshold <- Seq(-1, SQLConf.get.getConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD)) + } { withTable(t1, t2) { - sql(s"CREATE TABLE $t1 (x ARRAY>, i int) USING " + - s"PARQUET") - sql(s"INSERT INTO $t1 VALUES (array(array('${t.data1}')), 1)") + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> broadcastJoinThreshold.toString) { + sql(s"CREATE TABLE $t1 (x ARRAY>, i int) USING " + + s"PARQUET") + sql(s"INSERT INTO $t1 VALUES (array(array('${t.data1}')), 1)") - sql(s"CREATE TABLE $t2 (y ARRAY>, j int) USING " + - s"PARQUET") - sql(s"INSERT INTO $t2 VALUES (array(array('${t.data2}')), 2)," + - s" (array(array('${t.data1}')), 2)") + sql(s"CREATE TABLE $t2 (y ARRAY>, j int) USING " + + s"PARQUET") + sql(s"INSERT INTO $t2 VALUES (array(array('${t.data2}')), 2)," + + s" (array(array('${t.data1}')), 2)") - val df = sql(s"SELECT * FROM $t1 JOIN $t2 ON $t1.x = $t2.y") - checkAnswer(df, t.result) + val df = sql(s"SELECT * FROM $t1 JOIN $t2 ON $t1.x = $t2.y") + checkAnswer(df, t.result) - val queryPlan = df.queryExecution.executedPlan + val queryPlan = df.queryExecution.executedPlan - // confirm that hash join is used instead of sort merge join - assert( - collectFirst(queryPlan) { - case _: HashJoin => () - }.nonEmpty - ) - assert( - collectFirst(queryPlan) { - case _: ShuffledJoin => () - }.isEmpty - ) + // confirm that right kind of join is used. + checkRightTypeOfJoinUsed(queryPlan) - // Only if collation doesn't support binary equality, collation key should be injected. - if (!CollationFactory.fetchCollation(t.collation).isUtf8BinaryType) { - assert(collectFirst(queryPlan) { - case b: BroadcastHashJoinExec => b.leftKeys.head - }.head.asInstanceOf[ArrayTransform].function. - asInstanceOf[LambdaFunction].function.asInstanceOf[ArrayTransform].function. - asInstanceOf[LambdaFunction].function.isInstanceOf[CollationKey]) - } else { - assert(!collectFirst(queryPlan) { - case b: BroadcastHashJoinExec => b.leftKeys.head - }.head.isInstanceOf[ArrayTransform]) + if (isSortMergeForced) { + // Confirm proper injection of collation key. + checkCollationKeyInQueryPlan(queryPlan, t.collation) + } + else { + // Only if collation doesn't support binary equality, collation key should be injected. + if (!CollationFactory.fetchCollation(t.collation).supportsBinaryEquality) { + assert(collectFirst(queryPlan) { + case b: BroadcastHashJoinExec => b.leftKeys.head + }.head.asInstanceOf[ArrayTransform].function. + asInstanceOf[LambdaFunction].function.asInstanceOf[ArrayTransform].function. 
+ asInstanceOf[LambdaFunction].function.isInstanceOf[CollationKey]) + } else { + assert(!collectFirst(queryPlan) { + case b: BroadcastHashJoinExec => b.leftKeys.head + }.head.isInstanceOf[ArrayTransform]) + } + } } } - }) + } } - test("hash join should respect collation for struct of strings") { + test("hash and sort merge join should respect collation for struct of strings") { val t1 = "T_1" val t2 = "T_2" @@ -1779,43 +1809,36 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { HashJoinTestCase("UNICODE_CI_RTRIM", "aa", "AA ", Seq(Row(Row("aa"), 1, Row("AA "), 2), Row(Row("aa"), 1, Row("aa"), 2))) ) - testCases.foreach(t => { + for { + t <- testCases + broadcastJoinThreshold <- Seq(-1, SQLConf.get.getConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD)) + } { withTable(t1, t2) { - sql(s"CREATE TABLE $t1 (x STRUCT, i int) USING PARQUET") - sql(s"INSERT INTO $t1 VALUES (named_struct('f', '${t.data1}'), 1)") + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> broadcastJoinThreshold.toString) { + sql(s"CREATE TABLE $t1 (x STRUCT, i int) USING PARQUET") + sql(s"INSERT INTO $t1 VALUES (named_struct('f', '${t.data1}'), 1)") - sql(s"CREATE TABLE $t2 (y STRUCT, j int) USING PARQUET") - sql(s"INSERT INTO $t2 VALUES (named_struct('f', '${t.data2}'), 2)," + - s" (named_struct('f', '${t.data1}'), 2)") + sql(s"CREATE TABLE $t2 (y STRUCT, j int) USING PARQUET") + sql(s"INSERT INTO $t2 VALUES (named_struct('f', '${t.data2}'), 2)," + + s" (named_struct('f', '${t.data1}'), 2)") - val df = sql(s"SELECT * FROM $t1 JOIN $t2 ON $t1.x = $t2.y") - checkAnswer(df, t.result) + val df = sql(s"SELECT * FROM $t1 JOIN $t2 ON $t1.x = $t2.y") + checkAnswer(df, t.result) - val queryPlan = df.queryExecution.executedPlan + val queryPlan = df.queryExecution.executedPlan - // Confirm that hash join is used instead of sort merge join. - assert( - collectFirst(queryPlan) { - case _: HashJoin => () - }.nonEmpty - ) - assert( - collectFirst(queryPlan) { - case _: ShuffledJoin => () - }.isEmpty - ) + // confirm that right kind of join is used. + checkRightTypeOfJoinUsed(queryPlan) - // Only if collation doesn't support binary equality, collation key should be injected. - if (!CollationFactory.fetchCollation(t.collation).isUtf8BinaryType) { - assert(queryPlan.toString().contains("collationkey")) - } else { - assert(!queryPlan.toString().contains("collationkey")) + // Confirm proper injection of collation key. 
+ checkCollationKeyInQueryPlan(queryPlan, t.collation) } } - }) + } } - test("hash join should respect collation for struct of array of struct of strings") { + test("hash and sort merge join should respect collation " + + "for struct of array of struct of strings") { val t1 = "T_1" val t2 = "T_2" @@ -1835,43 +1858,36 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { Seq(Row(Row(Seq(Row("aa"))), 1, Row(Seq(Row("AA "))), 2), Row(Row(Seq(Row("aa"))), 1, Row(Seq(Row("aa"))), 2))) ) - testCases.foreach(t => { + + for { + t <- testCases + broadcastJoinThreshold <- Seq(-1, SQLConf.get.getConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD)) + } { withTable(t1, t2) { - sql(s"CREATE TABLE $t1 (x STRUCT>>, " + - s"i int) USING PARQUET") - sql(s"INSERT INTO $t1 VALUES (named_struct('f', array(named_struct('f', '${t.data1}'))), 1)" - ) + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> broadcastJoinThreshold.toString) { + sql(s"CREATE TABLE $t1 (x STRUCT>>, " + + s"i int) USING PARQUET") + sql(s"INSERT INTO $t1 VALUES (named_struct('f', array(named_struct('f', " + + s"'${t.data1}'))), 1)") - sql(s"CREATE TABLE $t2 (y STRUCT>>, " + - s"j int) USING PARQUET") - sql(s"INSERT INTO $t2 VALUES (named_struct('f', array(named_struct('f', '${t.data2}'))), 2)" - + s", (named_struct('f', array(named_struct('f', '${t.data1}'))), 2)") + sql(s"CREATE TABLE $t2 (y STRUCT>>, " + + s"j int) USING PARQUET") + sql(s"INSERT INTO $t2 VALUES (named_struct('f', array(named_struct('f', " + + s"'${t.data2}'))), 2), (named_struct('f', array(named_struct('f', '${t.data1}'))), 2)") - val df = sql(s"SELECT * FROM $t1 JOIN $t2 ON $t1.x = $t2.y") - checkAnswer(df, t.result) + val df = sql(s"SELECT * FROM $t1 JOIN $t2 ON $t1.x = $t2.y") + checkAnswer(df, t.result) - val queryPlan = df.queryExecution.executedPlan + val queryPlan = df.queryExecution.executedPlan - // confirm that hash join is used instead of sort merge join - assert( - collectFirst(queryPlan) { - case _: HashJoin => () - }.nonEmpty - ) - assert( - collectFirst(queryPlan) { - case _: ShuffledJoin => () - }.isEmpty - ) + // confirm that right kind of join is used. + checkRightTypeOfJoinUsed(queryPlan) - // Only if collation doesn't support binary equality, collation key should be injected. - if (!CollationFactory.fetchCollation(t.collation).isUtf8BinaryType) { - assert(queryPlan.toString().contains("collationkey")) - } else { - assert(!queryPlan.toString().contains("collationkey")) + // Confirm proper injection of collation key. 
+ checkCollationKeyInQueryPlan(queryPlan, t.collation) } } - }) + } } test("rewrite with collationkey should be a non-excludable rule") { @@ -1931,31 +1947,27 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { "'a', 'a', 1", "'A', 'A ', 1", Row("a", "a", 1, "A", "A ", 1)) ) - testCases.foreach(t => { + for { + t <- testCases + broadcastJoinThreshold <- Seq(-1, SQLConf.get.getConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD)) + } { withTable(t1, t2) { - sql(s"CREATE TABLE $t1 (x ${t.type1}, y ${t.type2}, i int) USING PARQUET") - sql(s"INSERT INTO $t1 VALUES (${t.data1})") - sql(s"CREATE TABLE $t2 (x ${t.type1}, y ${t.type2}, i int) USING PARQUET") - sql(s"INSERT INTO $t2 VALUES (${t.data2})") + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> broadcastJoinThreshold.toString) { + sql(s"CREATE TABLE $t1 (x ${t.type1}, y ${t.type2}, i int) USING PARQUET") + sql(s"INSERT INTO $t1 VALUES (${t.data1})") + sql(s"CREATE TABLE $t2 (x ${t.type1}, y ${t.type2}, i int) USING PARQUET") + sql(s"INSERT INTO $t2 VALUES (${t.data2})") - val df = sql(s"SELECT * FROM $t1 JOIN $t2 ON $t1.x = $t2.x AND $t1.y = $t2.y") - checkAnswer(df, t.result) + val df = sql(s"SELECT * FROM $t1 JOIN $t2 ON $t1.x = $t2.x AND $t1.y = $t2.y") + checkAnswer(df, t.result) - val queryPlan = df.queryExecution.executedPlan + val queryPlan = df.queryExecution.executedPlan - // confirm that hash join is used instead of sort merge join - assert( - collectFirst(queryPlan) { - case _: HashJoin => () - }.nonEmpty - ) - assert( - collectFirst(queryPlan) { - case _: SortMergeJoinExec => () - }.isEmpty - ) + // confirm that right kind of join is used. + checkRightTypeOfJoinUsed(queryPlan) + } } - }) + } } test("hll sketch aggregate should respect collation") { From bd94419c988ba115c6c05df18f60e17c066dfe78 Mon Sep 17 00:00:00 2001 From: Ruzel Ibragimov Date: Wed, 13 Nov 2024 16:30:16 +0100 Subject: [PATCH 24/39] [SPARK-50226][SQL] Correct MakeDTInterval and MakeYMInterval to catch Java exceptions ### What changes were proposed in this pull request? `MakeDTInterval` and `MakeYMInterval` do not catch Java exceptions in nullSafeEval like it does `MakeInterval`. So we making behavior similar. ### Why are the changes needed? To show to users readable nice error message. ### Does this PR introduce _any_ user-facing change? Improved error message ### How was this patch tested? There already were few tests to check behavior, I just changed expected error type. ### Was this patch authored or co-authored using generative AI tooling? Yes, Copilot used. Closes #48773 from gotocoding-DB/SPARK-50226-overflow-error. 
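
For context, the user-facing effect can be reproduced with the overflow cases from the updated golden files; the session setup below is an illustrative assumption:

```scala
import org.apache.spark.sql.SparkSession

// Assumed local session, for illustration only.
val spark = SparkSession.builder().master("local[*]").appName("interval-overflow-demo").getOrCreate()

// 178956970 years = 2147483640 months; adding 8 more months overflows the Int backing a
// year-month interval, so this now raises SparkArithmeticException with
// INTERVAL_ARITHMETIC_OVERFLOW.WITHOUT_SUGGESTION instead of a bare java.lang.ArithmeticException.
spark.sql("SELECT make_ym_interval(178956970, 8)").collect()

// Likewise, 2147483647 days overflows the Long count of microseconds backing a day-time interval.
spark.sql("SELECT make_dt_interval(2147483647)").collect()
```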
Authored-by: Ruzel Ibragimov Signed-off-by: Max Gekk --- .../resources/error/error-conditions.json | 14 +++- .../expressions/intervalExpressions.scala | 43 +++++++--- .../sql/catalyst/util/IntervalMathUtils.scala | 9 ++- .../sql/catalyst/util/IntervalUtils.scala | 18 +++-- .../sql/errors/QueryExecutionErrors.scala | 23 +++--- .../IntervalExpressionsSuite.scala | 6 +- .../sql-tests/results/ansi/interval.sql.out | 80 ++++++++++++------- .../sql-tests/results/interval.sql.out | 80 ++++++++++++------- .../spark/sql/DataFrameAggregateSuite.scala | 47 +++++++---- 9 files changed, 210 insertions(+), 110 deletions(-) diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json index e51b35c0accc2..5e1c3f46fd110 100644 --- a/common/utils/src/main/resources/error/error-conditions.json +++ b/common/utils/src/main/resources/error/error-conditions.json @@ -2012,8 +2012,20 @@ }, "INTERVAL_ARITHMETIC_OVERFLOW" : { "message" : [ - "." + "Integer overflow while operating with intervals." ], + "subClass" : { + "WITHOUT_SUGGESTION" : { + "message" : [ + "Try devising appropriate values for the interval parameters." + ] + }, + "WITH_SUGGESTION" : { + "message" : [ + "Use to tolerate overflow and return NULL instead." + ] + } + }, "sqlState" : "22015" }, "INTERVAL_DIVIDED_BY_ZERO" : { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala index 1ce7dfd39acc6..a7b67f55d8cd1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala @@ -481,7 +481,7 @@ case class MakeDTInterval( hours: Expression, mins: Expression, secs: Expression) - extends QuaternaryExpression with ImplicitCastInputTypes { + extends QuaternaryExpression with ImplicitCastInputTypes with SupportQueryContext { override def nullIntolerant: Boolean = true def this( @@ -514,13 +514,15 @@ case class MakeDTInterval( day.asInstanceOf[Int], hour.asInstanceOf[Int], min.asInstanceOf[Int], - sec.asInstanceOf[Decimal]) + sec.asInstanceOf[Decimal], + origin.context) } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { defineCodeGen(ctx, ev, (day, hour, min, sec) => { + val errorContext = getContextOrNullCode(ctx) val iu = IntervalUtils.getClass.getName.stripSuffix("$") - s"$iu.makeDayTimeInterval($day, $hour, $min, $sec)" + s"$iu.makeDayTimeInterval($day, $hour, $min, $sec, $errorContext)" }) } @@ -532,6 +534,8 @@ case class MakeDTInterval( mins: Expression, secs: Expression): MakeDTInterval = copy(days, hours, mins, secs) + + override def initQueryContext(): Option[QueryContext] = Some(origin.context) } @ExpressionDescription( @@ -556,7 +560,7 @@ case class MakeDTInterval( group = "datetime_funcs") // scalastyle:on line.size.limit case class MakeYMInterval(years: Expression, months: Expression) - extends BinaryExpression with ImplicitCastInputTypes with Serializable { + extends BinaryExpression with ImplicitCastInputTypes with Serializable with SupportQueryContext { override def nullIntolerant: Boolean = true def this(years: Expression) = this(years, Literal(0)) @@ -568,17 +572,28 @@ case class MakeYMInterval(years: Expression, months: Expression) override def dataType: DataType = YearMonthIntervalType() override def nullSafeEval(year: Any, month: Any): Any = { - 
Math.toIntExact(Math.addExact(month.asInstanceOf[Number].longValue(), - Math.multiplyExact(year.asInstanceOf[Number].longValue(), MONTHS_PER_YEAR))) + try { + Math.toIntExact( + Math.addExact(month.asInstanceOf[Int], + Math.multiplyExact(year.asInstanceOf[Int], MONTHS_PER_YEAR))) + } catch { + case _: ArithmeticException => + throw QueryExecutionErrors.withoutSuggestionIntervalArithmeticOverflowError(origin.context) + } } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - defineCodeGen(ctx, ev, (years, months) => { + nullSafeCodeGen(ctx, ev, (years, months) => { val math = classOf[Math].getName.stripSuffix("$") + val errorContext = getContextOrNullCode(ctx) + // scalastyle:off line.size.limit s""" - |$math.toIntExact(java.lang.Math.addExact($months, - | $math.multiplyExact($years, $MONTHS_PER_YEAR))) - |""".stripMargin + |try { + | ${ev.value} = $math.toIntExact($math.addExact($months, $math.multiplyExact($years, $MONTHS_PER_YEAR))); + |} catch (java.lang.ArithmeticException e) { + | throw QueryExecutionErrors.withoutSuggestionIntervalArithmeticOverflowError($errorContext); + |}""".stripMargin + // scalastyle:on line.size.limit }) } @@ -587,6 +602,10 @@ case class MakeYMInterval(years: Expression, months: Expression) override protected def withNewChildrenInternal( newLeft: Expression, newRight: Expression): Expression = copy(years = newLeft, months = newRight) + + override def initQueryContext(): Option[QueryContext] = { + Some(origin.context) + } } // Multiply an year-month interval by a numeric @@ -699,8 +718,8 @@ trait IntervalDivide { context: QueryContext): Unit = { if (value == minValue && num.dataType.isInstanceOf[IntegralType]) { if (numValue.asInstanceOf[Number].longValue() == -1) { - throw QueryExecutionErrors.intervalArithmeticOverflowError( - "Interval value overflows after being divided by -1", "try_divide", context) + throw QueryExecutionErrors.withSuggestionIntervalArithmeticOverflowError( + "try_divide", context) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalMathUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalMathUtils.scala index c935c60573763..756f2598f13f9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalMathUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalMathUtils.scala @@ -35,12 +35,15 @@ object IntervalMathUtils { def negateExact(a: Long): Long = withOverflow(Math.negateExact(a)) - private def withOverflow[A](f: => A, hint: String = ""): A = { + private def withOverflow[A](f: => A, suggestedFunc: String = ""): A = { try { f } catch { - case e: ArithmeticException => - throw QueryExecutionErrors.intervalArithmeticOverflowError(e.getMessage, hint, null) + case _: ArithmeticException if suggestedFunc.isEmpty => + throw QueryExecutionErrors.withoutSuggestionIntervalArithmeticOverflowError(context = null) + case _: ArithmeticException => + throw QueryExecutionErrors.withSuggestionIntervalArithmeticOverflowError( + suggestedFunc, context = null) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala index 90c802b7e28df..39a07990dea39 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala @@ -22,7 +22,7 @@ import 
java.util.concurrent.TimeUnit import scala.util.control.NonFatal -import org.apache.spark.{SparkIllegalArgumentException, SparkThrowable} +import org.apache.spark.{QueryContext, SparkIllegalArgumentException, SparkThrowable} import org.apache.spark.sql.catalyst.expressions.Literal import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.catalyst.util.DateTimeConstants._ @@ -782,13 +782,19 @@ object IntervalUtils extends SparkIntervalUtils { days: Int, hours: Int, mins: Int, - secs: Decimal): Long = { + secs: Decimal, + context: QueryContext): Long = { assert(secs.scale == 6, "Seconds fractional must have 6 digits for microseconds") var micros = secs.toUnscaledLong - micros = Math.addExact(micros, Math.multiplyExact(days, MICROS_PER_DAY)) - micros = Math.addExact(micros, Math.multiplyExact(hours, MICROS_PER_HOUR)) - micros = Math.addExact(micros, Math.multiplyExact(mins, MICROS_PER_MINUTE)) - micros + try { + micros = Math.addExact(micros, Math.multiplyExact(days, MICROS_PER_DAY)) + micros = Math.addExact(micros, Math.multiplyExact(hours, MICROS_PER_HOUR)) + micros = Math.addExact(micros, Math.multiplyExact(mins, MICROS_PER_MINUTE)) + micros + } catch { + case _: ArithmeticException => + throw QueryExecutionErrors.withoutSuggestionIntervalArithmeticOverflowError(context) + } } def intToYearMonthInterval(v: Int, startField: Byte, endField: Byte): Int = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala index 09836995925ea..fb39d3c5d7c6b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala @@ -636,18 +636,21 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE summary = "") } - def intervalArithmeticOverflowError( - message: String, - hint: String = "", + def withSuggestionIntervalArithmeticOverflowError( + suggestedFunc: String, context: QueryContext): ArithmeticException = { - val alternative = if (hint.nonEmpty) { - s" Use '$hint' to tolerate overflow and return NULL instead." 
- } else "" new SparkArithmeticException( - errorClass = "INTERVAL_ARITHMETIC_OVERFLOW", - messageParameters = Map( - "message" -> message, - "alternative" -> alternative), + errorClass = "INTERVAL_ARITHMETIC_OVERFLOW.WITH_SUGGESTION", + messageParameters = Map("functionName" -> toSQLId(suggestedFunc)), + context = getQueryContext(context), + summary = getSummary(context)) + } + + def withoutSuggestionIntervalArithmeticOverflowError( + context: QueryContext): SparkArithmeticException = { + new SparkArithmeticException( + errorClass = "INTERVAL_ARITHMETIC_OVERFLOW.WITHOUT_SUGGESTION", + messageParameters = Map(), context = getQueryContext(context), summary = getSummary(context)) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/IntervalExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/IntervalExpressionsSuite.scala index 78bc77b9dc2ab..8fb72ad53062e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/IntervalExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/IntervalExpressionsSuite.scala @@ -316,7 +316,8 @@ class IntervalExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val secFrac = DateTimeTestUtils.secFrac(seconds, millis, micros) val durationExpr = MakeDTInterval(Literal(days), Literal(hours), Literal(minutes), Literal(Decimal(secFrac, Decimal.MAX_LONG_DIGITS, 6))) - checkExceptionInExpression[ArithmeticException](durationExpr, EmptyRow, "") + checkExceptionInExpression[ArithmeticException]( + durationExpr, "INTERVAL_ARITHMETIC_OVERFLOW.WITHOUT_SUGGESTION") } check(millis = -123) @@ -528,7 +529,8 @@ class IntervalExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { Seq(MakeYMInterval(Literal(178956970), Literal(8)), MakeYMInterval(Literal(-178956970), Literal(-9))) .foreach { ym => - checkExceptionInExpression[ArithmeticException](ym, "integer overflow") + checkExceptionInExpression[ArithmeticException]( + ym, "INTERVAL_ARITHMETIC_OVERFLOW.WITHOUT_SUGGESTION") } def checkImplicitEvaluation(expr: Expression, value: Any): Unit = { diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/interval.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/interval.sql.out index 766bfba7696f0..4e012df792dea 100644 --- a/sql/core/src/test/resources/sql-tests/results/ansi/interval.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/ansi/interval.sql.out @@ -936,8 +936,18 @@ select make_dt_interval(2147483647) -- !query schema struct<> -- !query output -java.lang.ArithmeticException -long overflow +org.apache.spark.SparkArithmeticException +{ + "errorClass" : "INTERVAL_ARITHMETIC_OVERFLOW.WITHOUT_SUGGESTION", + "sqlState" : "22015", + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 35, + "fragment" : "make_dt_interval(2147483647)" + } ] +} -- !query @@ -977,8 +987,18 @@ select make_ym_interval(178956970, 8) -- !query schema struct<> -- !query output -java.lang.ArithmeticException -integer overflow +org.apache.spark.SparkArithmeticException +{ + "errorClass" : "INTERVAL_ARITHMETIC_OVERFLOW.WITHOUT_SUGGESTION", + "sqlState" : "22015", + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 37, + "fragment" : "make_ym_interval(178956970, 8)" + } ] +} -- !query @@ -994,8 +1014,18 @@ select make_ym_interval(-178956970, -9) -- !query schema struct<> -- !query output -java.lang.ArithmeticException 
-integer overflow +org.apache.spark.SparkArithmeticException +{ + "errorClass" : "INTERVAL_ARITHMETIC_OVERFLOW.WITHOUT_SUGGESTION", + "sqlState" : "22015", + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 39, + "fragment" : "make_ym_interval(-178956970, -9)" + } ] +} -- !query @@ -2493,12 +2523,8 @@ struct<> -- !query output org.apache.spark.SparkArithmeticException { - "errorClass" : "INTERVAL_ARITHMETIC_OVERFLOW", - "sqlState" : "22015", - "messageParameters" : { - "alternative" : "", - "message" : "integer overflow" - } + "errorClass" : "INTERVAL_ARITHMETIC_OVERFLOW.WITHOUT_SUGGESTION", + "sqlState" : "22015" } @@ -2509,11 +2535,10 @@ struct<> -- !query output org.apache.spark.SparkArithmeticException { - "errorClass" : "INTERVAL_ARITHMETIC_OVERFLOW", + "errorClass" : "INTERVAL_ARITHMETIC_OVERFLOW.WITH_SUGGESTION", "sqlState" : "22015", "messageParameters" : { - "alternative" : " Use 'try_subtract' to tolerate overflow and return NULL instead.", - "message" : "integer overflow" + "functionName" : "`try_subtract`" } } @@ -2525,11 +2550,10 @@ struct<> -- !query output org.apache.spark.SparkArithmeticException { - "errorClass" : "INTERVAL_ARITHMETIC_OVERFLOW", + "errorClass" : "INTERVAL_ARITHMETIC_OVERFLOW.WITH_SUGGESTION", "sqlState" : "22015", "messageParameters" : { - "alternative" : " Use 'try_add' to tolerate overflow and return NULL instead.", - "message" : "integer overflow" + "functionName" : "`try_add`" } } @@ -2838,11 +2862,10 @@ struct<> -- !query output org.apache.spark.SparkArithmeticException { - "errorClass" : "INTERVAL_ARITHMETIC_OVERFLOW", + "errorClass" : "INTERVAL_ARITHMETIC_OVERFLOW.WITH_SUGGESTION", "sqlState" : "22015", "messageParameters" : { - "alternative" : " Use 'try_divide' to tolerate overflow and return NULL instead.", - "message" : "Interval value overflows after being divided by -1" + "functionName" : "`try_divide`" }, "queryContext" : [ { "objectType" : "", @@ -2861,11 +2884,10 @@ struct<> -- !query output org.apache.spark.SparkArithmeticException { - "errorClass" : "INTERVAL_ARITHMETIC_OVERFLOW", + "errorClass" : "INTERVAL_ARITHMETIC_OVERFLOW.WITH_SUGGESTION", "sqlState" : "22015", "messageParameters" : { - "alternative" : " Use 'try_divide' to tolerate overflow and return NULL instead.", - "message" : "Interval value overflows after being divided by -1" + "functionName" : "`try_divide`" }, "queryContext" : [ { "objectType" : "", @@ -2918,11 +2940,10 @@ struct<> -- !query output org.apache.spark.SparkArithmeticException { - "errorClass" : "INTERVAL_ARITHMETIC_OVERFLOW", + "errorClass" : "INTERVAL_ARITHMETIC_OVERFLOW.WITH_SUGGESTION", "sqlState" : "22015", "messageParameters" : { - "alternative" : " Use 'try_divide' to tolerate overflow and return NULL instead.", - "message" : "Interval value overflows after being divided by -1" + "functionName" : "`try_divide`" }, "queryContext" : [ { "objectType" : "", @@ -2941,11 +2962,10 @@ struct<> -- !query output org.apache.spark.SparkArithmeticException { - "errorClass" : "INTERVAL_ARITHMETIC_OVERFLOW", + "errorClass" : "INTERVAL_ARITHMETIC_OVERFLOW.WITH_SUGGESTION", "sqlState" : "22015", "messageParameters" : { - "alternative" : " Use 'try_divide' to tolerate overflow and return NULL instead.", - "message" : "Interval value overflows after being divided by -1" + "functionName" : "`try_divide`" }, "queryContext" : [ { "objectType" : "", diff --git a/sql/core/src/test/resources/sql-tests/results/interval.sql.out 
b/sql/core/src/test/resources/sql-tests/results/interval.sql.out index 7eed2d42da043..a8a0423bdb3e0 100644 --- a/sql/core/src/test/resources/sql-tests/results/interval.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/interval.sql.out @@ -823,8 +823,18 @@ select make_dt_interval(2147483647) -- !query schema struct<> -- !query output -java.lang.ArithmeticException -long overflow +org.apache.spark.SparkArithmeticException +{ + "errorClass" : "INTERVAL_ARITHMETIC_OVERFLOW.WITHOUT_SUGGESTION", + "sqlState" : "22015", + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 35, + "fragment" : "make_dt_interval(2147483647)" + } ] +} -- !query @@ -864,8 +874,18 @@ select make_ym_interval(178956970, 8) -- !query schema struct<> -- !query output -java.lang.ArithmeticException -integer overflow +org.apache.spark.SparkArithmeticException +{ + "errorClass" : "INTERVAL_ARITHMETIC_OVERFLOW.WITHOUT_SUGGESTION", + "sqlState" : "22015", + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 37, + "fragment" : "make_ym_interval(178956970, 8)" + } ] +} -- !query @@ -881,8 +901,18 @@ select make_ym_interval(-178956970, -9) -- !query schema struct<> -- !query output -java.lang.ArithmeticException -integer overflow +org.apache.spark.SparkArithmeticException +{ + "errorClass" : "INTERVAL_ARITHMETIC_OVERFLOW.WITHOUT_SUGGESTION", + "sqlState" : "22015", + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 39, + "fragment" : "make_ym_interval(-178956970, -9)" + } ] +} -- !query @@ -2316,12 +2346,8 @@ struct<> -- !query output org.apache.spark.SparkArithmeticException { - "errorClass" : "INTERVAL_ARITHMETIC_OVERFLOW", - "sqlState" : "22015", - "messageParameters" : { - "alternative" : "", - "message" : "integer overflow" - } + "errorClass" : "INTERVAL_ARITHMETIC_OVERFLOW.WITHOUT_SUGGESTION", + "sqlState" : "22015" } @@ -2332,11 +2358,10 @@ struct<> -- !query output org.apache.spark.SparkArithmeticException { - "errorClass" : "INTERVAL_ARITHMETIC_OVERFLOW", + "errorClass" : "INTERVAL_ARITHMETIC_OVERFLOW.WITH_SUGGESTION", "sqlState" : "22015", "messageParameters" : { - "alternative" : " Use 'try_subtract' to tolerate overflow and return NULL instead.", - "message" : "integer overflow" + "functionName" : "`try_subtract`" } } @@ -2348,11 +2373,10 @@ struct<> -- !query output org.apache.spark.SparkArithmeticException { - "errorClass" : "INTERVAL_ARITHMETIC_OVERFLOW", + "errorClass" : "INTERVAL_ARITHMETIC_OVERFLOW.WITH_SUGGESTION", "sqlState" : "22015", "messageParameters" : { - "alternative" : " Use 'try_add' to tolerate overflow and return NULL instead.", - "message" : "integer overflow" + "functionName" : "`try_add`" } } @@ -2661,11 +2685,10 @@ struct<> -- !query output org.apache.spark.SparkArithmeticException { - "errorClass" : "INTERVAL_ARITHMETIC_OVERFLOW", + "errorClass" : "INTERVAL_ARITHMETIC_OVERFLOW.WITH_SUGGESTION", "sqlState" : "22015", "messageParameters" : { - "alternative" : " Use 'try_divide' to tolerate overflow and return NULL instead.", - "message" : "Interval value overflows after being divided by -1" + "functionName" : "`try_divide`" }, "queryContext" : [ { "objectType" : "", @@ -2684,11 +2707,10 @@ struct<> -- !query output org.apache.spark.SparkArithmeticException { - "errorClass" : "INTERVAL_ARITHMETIC_OVERFLOW", + "errorClass" : "INTERVAL_ARITHMETIC_OVERFLOW.WITH_SUGGESTION", "sqlState" : "22015", "messageParameters" : { - "alternative" : " Use 
'try_divide' to tolerate overflow and return NULL instead.", - "message" : "Interval value overflows after being divided by -1" + "functionName" : "`try_divide`" }, "queryContext" : [ { "objectType" : "", @@ -2741,11 +2763,10 @@ struct<> -- !query output org.apache.spark.SparkArithmeticException { - "errorClass" : "INTERVAL_ARITHMETIC_OVERFLOW", + "errorClass" : "INTERVAL_ARITHMETIC_OVERFLOW.WITH_SUGGESTION", "sqlState" : "22015", "messageParameters" : { - "alternative" : " Use 'try_divide' to tolerate overflow and return NULL instead.", - "message" : "Interval value overflows after being divided by -1" + "functionName" : "`try_divide`" }, "queryContext" : [ { "objectType" : "", @@ -2764,11 +2785,10 @@ struct<> -- !query output org.apache.spark.SparkArithmeticException { - "errorClass" : "INTERVAL_ARITHMETIC_OVERFLOW", + "errorClass" : "INTERVAL_ARITHMETIC_OVERFLOW.WITH_SUGGESTION", "sqlState" : "22015", "messageParameters" : { - "alternative" : " Use 'try_divide' to tolerate overflow and return NULL instead.", - "message" : "Interval value overflows after being divided by -1" + "functionName" : "`try_divide`" }, "queryContext" : [ { "objectType" : "", diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 7ebcb280def6e..6348e5f315395 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -26,6 +26,7 @@ import org.scalatest.matchers.must.Matchers.the import org.apache.spark.{SparkArithmeticException, SparkRuntimeException} import org.apache.spark.sql.catalyst.plans.logical.Expand import org.apache.spark.sql.catalyst.util.AUTO_GENERATED_ALIAS +import org.apache.spark.sql.errors.DataTypeErrors.toSQLId import org.apache.spark.sql.execution.WholeStageCodegenExec import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec} @@ -1485,15 +1486,22 @@ class DataFrameAggregateSuite extends QueryTest val df2 = Seq((Period.ofMonths(Int.MaxValue), Duration.ofDays(106751991)), (Period.ofMonths(10), Duration.ofDays(10))) .toDF("year-month", "day") - val error = intercept[SparkArithmeticException] { - checkAnswer(df2.select(sum($"year-month")), Nil) - } - assert(error.getMessage contains "[INTERVAL_ARITHMETIC_OVERFLOW] integer overflow") - val error2 = intercept[SparkArithmeticException] { - checkAnswer(df2.select(sum($"day")), Nil) - } - assert(error2.getMessage contains "[INTERVAL_ARITHMETIC_OVERFLOW] long overflow") + checkError( + exception = intercept[SparkArithmeticException] { + checkAnswer(df2.select(sum($"year-month")), Nil) + }, + condition = "INTERVAL_ARITHMETIC_OVERFLOW.WITH_SUGGESTION", + parameters = Map("functionName" -> toSQLId("try_add")) + ) + + checkError( + exception = intercept[SparkArithmeticException] { + checkAnswer(df2.select(sum($"day")), Nil) + }, + condition = "INTERVAL_ARITHMETIC_OVERFLOW.WITH_SUGGESTION", + parameters = Map("functionName" -> toSQLId("try_add")) + ) } test("SPARK-34837: Support ANSI SQL intervals by the aggregate function `avg`") { @@ -1620,15 +1628,22 @@ class DataFrameAggregateSuite extends QueryTest val df2 = Seq((Period.ofMonths(Int.MaxValue), Duration.ofDays(106751991)), (Period.ofMonths(10), Duration.ofDays(10))) .toDF("year-month", "day") - val error = intercept[SparkArithmeticException] { - 
checkAnswer(df2.select(avg($"year-month")), Nil) - } - assert(error.getMessage contains "[INTERVAL_ARITHMETIC_OVERFLOW] integer overflow") - val error2 = intercept[SparkArithmeticException] { - checkAnswer(df2.select(avg($"day")), Nil) - } - assert(error2.getMessage contains "[INTERVAL_ARITHMETIC_OVERFLOW] long overflow") + checkError( + exception = intercept[SparkArithmeticException] { + checkAnswer(df2.select(avg($"year-month")), Nil) + }, + condition = "INTERVAL_ARITHMETIC_OVERFLOW.WITH_SUGGESTION", + parameters = Map("functionName" -> toSQLId("try_add")) + ) + + checkError( + exception = intercept[SparkArithmeticException] { + checkAnswer(df2.select(avg($"day")), Nil) + }, + condition = "INTERVAL_ARITHMETIC_OVERFLOW.WITH_SUGGESTION", + parameters = Map("functionName" -> toSQLId("try_add")) + ) val df3 = intervalData.filter($"class" > 4) val avgDF3 = df3.select(avg($"year-month"), avg($"day")) From bc9b2597ea2d99620918f809a3db8739968e42a3 Mon Sep 17 00:00:00 2001 From: panbingkun Date: Wed, 13 Nov 2024 16:36:02 +0100 Subject: [PATCH 25/39] [SPARK-50066][SQL] Codegen Support for `SchemaOfXml` (by Invoke & RuntimeReplaceable) ### What changes were proposed in this pull request? The pr aims to add `Codegen` Support for `schema_of_xml`. ### Why are the changes needed? - improve codegen coverage. - simplified code. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Pass GA & Existed UT (eg: XmlFunctionsSuite#`*schema_of_xml*`) ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48594 from panbingkun/SPARK-50066. Authored-by: panbingkun Signed-off-by: Max Gekk --- .../xml/XmlExpressionEvalUtils.scala | 42 +++++++++++++++++++ .../catalyst/expressions/xmlExpressions.scala | 34 +++++++-------- 2 files changed, 58 insertions(+), 18 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/XmlExpressionEvalUtils.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/XmlExpressionEvalUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/XmlExpressionEvalUtils.scala new file mode 100644 index 0000000000000..dff88475327a2 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/XmlExpressionEvalUtils.scala @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.expressions.xml + +import org.apache.spark.sql.catalyst.xml.XmlInferSchema +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.{ArrayType, DataType, StructType} +import org.apache.spark.unsafe.types.UTF8String + +object XmlExpressionEvalUtils { + + def schemaOfXml(xmlInferSchema: XmlInferSchema, xml: UTF8String): UTF8String = { + val dataType = xmlInferSchema.infer(xml.toString).get match { + case st: StructType => + xmlInferSchema.canonicalizeType(st).getOrElse(StructType(Nil)) + case at: ArrayType if at.elementType.isInstanceOf[StructType] => + xmlInferSchema + .canonicalizeType(at.elementType) + .map(ArrayType(_, containsNull = at.containsNull)) + .getOrElse(ArrayType(StructType(Nil), containsNull = at.containsNull)) + case other: DataType => + xmlInferSchema.canonicalizeType(other).getOrElse(SQLConf.get.defaultStringType) + } + + UTF8String.fromString(dataType.sql) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xmlExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xmlExpressions.scala index 196c0793e6193..6f004cbce4262 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xmlExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xmlExpressions.scala @@ -21,7 +21,9 @@ import java.io.CharArrayWriter import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{DataTypeMismatch, TypeCheckSuccess} -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodegenFallback, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke +import org.apache.spark.sql.catalyst.expressions.xml.XmlExpressionEvalUtils import org.apache.spark.sql.catalyst.util.{DropMalformedMode, FailFastMode, FailureSafeParser, PermissiveMode} import org.apache.spark.sql.catalyst.util.TypeUtils._ import org.apache.spark.sql.catalyst.xml.{StaxXmlGenerator, StaxXmlParser, ValidatorUtil, XmlInferSchema, XmlOptions} @@ -149,7 +151,9 @@ case class XmlToStructs( case class SchemaOfXml( child: Expression, options: Map[String, String]) - extends UnaryExpression with CodegenFallback with QueryErrorsBase { + extends UnaryExpression + with RuntimeReplaceable + with QueryErrorsBase { def this(child: Expression) = this(child, Map.empty[String, String]) @@ -192,26 +196,20 @@ case class SchemaOfXml( } } - override def eval(v: InternalRow): Any = { - val dataType = xmlInferSchema.infer(xml.toString).get match { - case st: StructType => - xmlInferSchema.canonicalizeType(st).getOrElse(StructType(Nil)) - case at: ArrayType if at.elementType.isInstanceOf[StructType] => - xmlInferSchema - .canonicalizeType(at.elementType) - .map(ArrayType(_, containsNull = at.containsNull)) - .getOrElse(ArrayType(StructType(Nil), containsNull = at.containsNull)) - case other: DataType => - xmlInferSchema.canonicalizeType(other).getOrElse(SQLConf.get.defaultStringType) - } - - UTF8String.fromString(dataType.sql) - } - override def prettyName: String = "schema_of_xml" override protected def withNewChildInternal(newChild: Expression): SchemaOfXml = copy(child = newChild) + + @transient private lazy val xmlInferSchemaObjectType = ObjectType(classOf[XmlInferSchema]) + + override def replacement: Expression = 
StaticInvoke( + XmlExpressionEvalUtils.getClass, + dataType, + "schemaOfXml", + Seq(Literal(xmlInferSchema, xmlInferSchemaObjectType), child), + Seq(xmlInferSchemaObjectType, child.dataType) + ) } /** From 558fc89f8ccf631cf12e9838d57c6aaa77696c03 Mon Sep 17 00:00:00 2001 From: Mihailo Milosevic Date: Wed, 13 Nov 2024 16:40:16 +0100 Subject: [PATCH 26/39] [SPARK-49611][SQL][FOLLOW-UP] Make collations TVF consistent and return null on no result for country and language ### What changes were proposed in this pull request? It was noticed that we return null for country and language for collations TVF when collation is UTF8_*, but when information is missing in ICU we return empty string. ### Why are the changes needed? Making behaviour consistent. ### Does this PR introduce _any_ user-facing change? No, this is all in Spark 4.0, so addition of this TVF was not released yet. ### How was this patch tested? Existing test. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48835 from mihailom-db/fix-collations-table. Authored-by: Mihailo Milosevic Signed-off-by: Max Gekk --- .../sql/catalyst/util/CollationFactory.java | 6 ++++-- .../org/apache/spark/sql/CollationSuite.scala | 16 ++++++++-------- 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java index ad5e5ae845f85..4064f830e92d8 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java @@ -1023,12 +1023,14 @@ protected Collation buildCollation() { @Override protected CollationMeta buildCollationMeta() { + String language = ICULocaleMap.get(locale).getDisplayLanguage(); + String country = ICULocaleMap.get(locale).getDisplayCountry(); return new CollationMeta( CATALOG, SCHEMA, normalizedCollationName(), - ICULocaleMap.get(locale).getDisplayLanguage(), - ICULocaleMap.get(locale).getDisplayCountry(), + language.isEmpty() ? null : language, + country.isEmpty() ? 
null : country, VersionInfo.ICU_VERSION.toString(), COLLATION_PAD_ATTRIBUTE, accentSensitivity == AccentSensitivity.AS, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala index 9716d342bb6bc..f5cb30809ae50 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala @@ -2037,21 +2037,21 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { "ACCENT_SENSITIVE", "CASE_SENSITIVE", "NO_PAD", null), Row("SYSTEM", "BUILTIN", "UTF8_LCASE", null, null, "ACCENT_SENSITIVE", "CASE_INSENSITIVE", "NO_PAD", null), - Row("SYSTEM", "BUILTIN", "UNICODE", "", "", + Row("SYSTEM", "BUILTIN", "UNICODE", null, null, "ACCENT_SENSITIVE", "CASE_SENSITIVE", "NO_PAD", icvVersion), - Row("SYSTEM", "BUILTIN", "UNICODE_AI", "", "", + Row("SYSTEM", "BUILTIN", "UNICODE_AI", null, null, "ACCENT_INSENSITIVE", "CASE_SENSITIVE", "NO_PAD", icvVersion), - Row("SYSTEM", "BUILTIN", "UNICODE_CI", "", "", + Row("SYSTEM", "BUILTIN", "UNICODE_CI", null, null, "ACCENT_SENSITIVE", "CASE_INSENSITIVE", "NO_PAD", icvVersion), - Row("SYSTEM", "BUILTIN", "UNICODE_CI_AI", "", "", + Row("SYSTEM", "BUILTIN", "UNICODE_CI_AI", null, null, "ACCENT_INSENSITIVE", "CASE_INSENSITIVE", "NO_PAD", icvVersion), - Row("SYSTEM", "BUILTIN", "af", "Afrikaans", "", + Row("SYSTEM", "BUILTIN", "af", "Afrikaans", null, "ACCENT_SENSITIVE", "CASE_SENSITIVE", "NO_PAD", icvVersion), - Row("SYSTEM", "BUILTIN", "af_AI", "Afrikaans", "", + Row("SYSTEM", "BUILTIN", "af_AI", "Afrikaans", null, "ACCENT_INSENSITIVE", "CASE_SENSITIVE", "NO_PAD", icvVersion), - Row("SYSTEM", "BUILTIN", "af_CI", "Afrikaans", "", + Row("SYSTEM", "BUILTIN", "af_CI", "Afrikaans", null, "ACCENT_SENSITIVE", "CASE_INSENSITIVE", "NO_PAD", icvVersion), - Row("SYSTEM", "BUILTIN", "af_CI_AI", "Afrikaans", "", + Row("SYSTEM", "BUILTIN", "af_CI_AI", "Afrikaans", null, "ACCENT_INSENSITIVE", "CASE_INSENSITIVE", "NO_PAD", icvVersion))) checkAnswer(sql("SELECT * FROM collations() WHERE NAME LIKE '%UTF8_BINARY%'"), From 7b1b450bb65b49f8a5c9e2d9ebd1e01a2e4e3880 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vladan=20Vasi=C4=87?= Date: Wed, 13 Nov 2024 16:43:59 +0100 Subject: [PATCH 27/39] Revert [SPARK-50215][SQL] Refactored StringType pattern matching in jdbc code stack MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### What changes were proposed in this pull request? I propose reverting the PR for changing pattern matching of `StringType` in the jdbc code stack, since it may lead to collated column being mapped to uncollated column in some dialects. For the time being, this is not the correct behavior. ### Why are the changes needed? These changes are needed in order to preserve proper behavior in the dialects regarding datatype mapping. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? No testing was needed. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48833 from vladanvasi-db/vladanvasi-db/jdbc-refactor-revert. 
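As an illustration of the pattern-matching distinction this revert restores, here is a minimal, self-contained Scala sketch. It does not use Spark's real `org.apache.spark.sql.types.StringType`; the simplified class, collation ids, and `jdbcTypeFor` helper are hypothetical stand-ins, meant only to show why a type pattern (`case _: StringType`) also matches collated string types, while the stable-identifier pattern `case StringType` matches only the default one, so a collated column no longer silently maps to an uncollated JDBC type:

```
// Minimal sketch under the assumptions above; not Spark's actual type hierarchy.
class StringType(val collationId: Int) {
  override def equals(other: Any): Boolean = other match {
    case s: StringType => s.collationId == collationId
    case _             => false
  }
  override def hashCode(): Int = collationId
}

// The companion object stands in for the default (binary-collated) string type.
object StringType extends StringType(0)

def jdbcTypeFor(dt: StringType): String = dt match {
  // Stable-identifier pattern: only values equal to the default object match.
  case StringType => "TEXT"
  // A type pattern (`case _: StringType`) would also match collated variants,
  // which is exactly the mapping this revert avoids for now.
  case _ => "unmapped"
}

println(jdbcTypeFor(StringType))        // TEXT      (default collation)
println(jdbcTypeFor(new StringType(1))) // unmapped  (collated variant falls through)
```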
Authored-by: Vladan Vasić Signed-off-by: Max Gekk --- .../src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala | 2 +- .../scala/org/apache/spark/sql/jdbc/DatabricksDialect.scala | 2 +- .../src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala | 2 +- .../src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala | 2 +- .../scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala | 2 +- .../src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala | 2 +- .../main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala | 2 +- .../main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala | 2 +- .../main/scala/org/apache/spark/sql/jdbc/TeradataDialect.scala | 2 +- 9 files changed, 9 insertions(+), 9 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala index 3256803f60395..2f54f1f62fde1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala @@ -101,7 +101,7 @@ private case class DB2Dialect() extends JdbcDialect with SQLConfHelper with NoLe } override def getJDBCType(dt: DataType): Option[JdbcType] = dt match { - case _: StringType => Option(JdbcType("CLOB", java.sql.Types.CLOB)) + case StringType => Option(JdbcType("CLOB", java.sql.Types.CLOB)) case BooleanType if conf.legacyDB2BooleanMappingEnabled => Option(JdbcType("CHAR(1)", java.sql.Types.CHAR)) case BooleanType => Option(JdbcType("BOOLEAN", java.sql.Types.BOOLEAN)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DatabricksDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DatabricksDialect.scala index 3b855b376967d..af77f8575dd86 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DatabricksDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DatabricksDialect.scala @@ -44,7 +44,7 @@ private case class DatabricksDialect() extends JdbcDialect with NoLegacyJDBCErro override def getJDBCType(dt: DataType): Option[JdbcType] = dt match { case BooleanType => Some(JdbcType("BOOLEAN", java.sql.Types.BOOLEAN)) case DoubleType => Some(JdbcType("DOUBLE", java.sql.Types.DOUBLE)) - case _: StringType => Some(JdbcType("STRING", java.sql.Types.VARCHAR)) + case StringType => Some(JdbcType("STRING", java.sql.Types.VARCHAR)) case BinaryType => Some(JdbcType("BINARY", java.sql.Types.BINARY)) case _ => None } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala index f78e155d485db..7b65a01b5e702 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala @@ -44,7 +44,7 @@ private case class DerbyDialect() extends JdbcDialect with NoLegacyJDBCError { } override def getJDBCType(dt: DataType): Option[JdbcType] = dt match { - case _: StringType => Option(JdbcType("CLOB", java.sql.Types.CLOB)) + case StringType => Option(JdbcType("CLOB", java.sql.Types.CLOB)) case ByteType => Option(JdbcType("SMALLINT", java.sql.Types.SMALLINT)) case ShortType => Option(JdbcType("SMALLINT", java.sql.Types.SMALLINT)) case BooleanType => Option(JdbcType("BOOLEAN", java.sql.Types.BOOLEAN)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala index 5e5ba797ca608..798ecb5b36ff2 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala @@ -73,7 +73,7 @@ private[sql] case class H2Dialect() extends JdbcDialect with NoLegacyJDBCError { } override def getJDBCType(dt: DataType): Option[JdbcType] = dt match { - case _: StringType => Option(JdbcType("CLOB", Types.CLOB)) + case StringType => Option(JdbcType("CLOB", Types.CLOB)) case BooleanType => Some(JdbcType("BOOLEAN", Types.BOOLEAN)) case ShortType | ByteType => Some(JdbcType("SMALLINT", Types.SMALLINT)) case t: DecimalType => Some( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala index a29f3d9550d1d..7d476d43e5c7a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala @@ -135,7 +135,7 @@ private case class MsSqlServerDialect() extends JdbcDialect with NoLegacyJDBCErr override def getJDBCType(dt: DataType): Option[JdbcType] = dt match { case TimestampType => Some(JdbcType("DATETIME", java.sql.Types.TIMESTAMP)) case TimestampNTZType => Some(JdbcType("DATETIME", java.sql.Types.TIMESTAMP)) - case _: StringType => Some(JdbcType("NVARCHAR(MAX)", java.sql.Types.NVARCHAR)) + case StringType => Some(JdbcType("NVARCHAR(MAX)", java.sql.Types.NVARCHAR)) case BooleanType => Some(JdbcType("BIT", java.sql.Types.BIT)) case BinaryType => Some(JdbcType("VARBINARY(MAX)", java.sql.Types.VARBINARY)) case ShortType if !SQLConf.get.legacyMsSqlServerNumericMappingEnabled => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala index c4f2793707e5b..dd0118d875998 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala @@ -256,7 +256,7 @@ private case class MySQLDialect() extends JdbcDialect with SQLConfHelper with No // See SPARK-35446: MySQL treats REAL as a synonym to DOUBLE by default // We override getJDBCType so that FloatType is mapped to FLOAT instead case FloatType => Option(JdbcType("FLOAT", java.sql.Types.FLOAT)) - case _: StringType => Option(JdbcType("LONGTEXT", java.sql.Types.LONGVARCHAR)) + case StringType => Option(JdbcType("LONGTEXT", java.sql.Types.LONGVARCHAR)) case ByteType => Option(JdbcType("TINYINT", java.sql.Types.TINYINT)) case ShortType => Option(JdbcType("SMALLINT", java.sql.Types.SMALLINT)) // scalastyle:off line.size.limit diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala index 9c8a6bf5e145f..a73a34c646356 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala @@ -121,7 +121,7 @@ private case class OracleDialect() extends JdbcDialect with SQLConfHelper with N case DoubleType => Some(JdbcType("NUMBER(19, 4)", java.sql.Types.DOUBLE)) case ByteType => Some(JdbcType("NUMBER(3)", java.sql.Types.SMALLINT)) case ShortType => Some(JdbcType("NUMBER(5)", java.sql.Types.SMALLINT)) - case _: StringType => Some(JdbcType("VARCHAR2(255)", java.sql.Types.VARCHAR)) + case StringType => Some(JdbcType("VARCHAR2(255)", java.sql.Types.VARCHAR)) case VarcharType(n) => Some(JdbcType(s"VARCHAR2($n)", java.sql.Types.VARCHAR)) case 
TimestampType if !conf.legacyOracleTimestampMappingEnabled => Some(JdbcType("TIMESTAMP WITH LOCAL TIME ZONE", TIMESTAMP_LTZ)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala index 1265550b3f19d..8341063e09890 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala @@ -145,7 +145,7 @@ private case class PostgresDialect() } override def getJDBCType(dt: DataType): Option[JdbcType] = dt match { - case _: StringType => Some(JdbcType("TEXT", Types.VARCHAR)) + case StringType => Some(JdbcType("TEXT", Types.VARCHAR)) case BinaryType => Some(JdbcType("BYTEA", Types.BINARY)) case BooleanType => Some(JdbcType("BOOLEAN", Types.BOOLEAN)) case FloatType => Some(JdbcType("FLOAT4", Types.FLOAT)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/TeradataDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/TeradataDialect.scala index c7d8e899d71b0..322b259485f56 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/TeradataDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/TeradataDialect.scala @@ -40,7 +40,7 @@ private case class TeradataDialect() extends JdbcDialect with NoLegacyJDBCError supportedFunctions.contains(funcName) override def getJDBCType(dt: DataType): Option[JdbcType] = dt match { - case _: StringType => Some(JdbcType("VARCHAR(255)", java.sql.Types.VARCHAR)) + case StringType => Some(JdbcType("VARCHAR(255)", java.sql.Types.VARCHAR)) case BooleanType => Option(JdbcType("CHAR(1)", java.sql.Types.CHAR)) case ByteType => Option(JdbcType("BYTEINT", java.sql.Types.TINYINT)) case _ => None From 87ad4b4a2cfbb1b1c5d5374d3fea848b1e0dac8b Mon Sep 17 00:00:00 2001 From: yangjie01 Date: Wed, 13 Nov 2024 09:22:08 -0800 Subject: [PATCH 28/39] [SPARK-50139][INFRA][SS][PYTHON] Introduce scripts to re-generate and checking StateMessage_pb2.py and StateMessage_pb2.pyi ### What changes were proposed in this pull request? This pr includes the following changes: 1. Refactor the `dev/connect-gen-protos.sh` script to support the generation of `.py` files from `.proto` files for both the `connect` and `streaming` modules simultaneously. Rename the script to `dev/gen-protos.sh`. Additionally, to maintain compatibility with previous development practices, this pull request (PR) introduces `dev/connect-gen-protos.sh` and `dev/streaming-gen-protos.sh` as wrappers around `dev/gen-protos.sh`. After this PR, you can use: ``` dev/gen-protos.sh connect dev/gen-protos.sh streaming ``` or ``` dev/connect-gen-protos.sh dev/streaming-gen-protos.sh ``` to regenerate the corresponding `.py` files for the respective modules. 2. Refactor the `dev/connect-check-protos.py` script to check the generated results for both the `connect` and `streaming` modules simultaneously, and rename it to `dev/check-protos.py`. Additionally, update the invocation of the check script in `build_and_test.yml`. ### Why are the changes needed? Provid tools for re-generate and checking `StateMessage_pb2.py` and `StateMessage_pb2.pyi`. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? - Pass GitHub Actions ### Was this patch authored or co-authored using generative AI tooling? No Closes #48815 from LuciferYang/streaming-gen-protos. 
Lead-authored-by: yangjie01 Co-authored-by: YangJie Signed-off-by: Dongjoon Hyun --- .github/workflows/build_and_test.yml | 6 +- ...onnect-check-protos.py => check-protos.py} | 23 +- dev/connect-gen-protos.sh | 78 +- dev/gen-protos.sh | 127 ++ dev/streaming-gen-protos.sh | 27 + dev/tox.ini | 1 + .../sql/streaming/proto/StateMessage_pb2.py | 173 +- .../sql/streaming/proto/StateMessage_pb2.pyi | 1552 ++++++++++++----- sql/core/src/main/buf.gen.yaml | 24 + sql/core/src/main/buf.work.yaml | 19 + 10 files changed, 1423 insertions(+), 607 deletions(-) rename dev/{connect-check-protos.py => check-protos.py} (73%) create mode 100755 dev/gen-protos.sh create mode 100755 dev/streaming-gen-protos.sh create mode 100644 sql/core/src/main/buf.gen.yaml create mode 100644 sql/core/src/main/buf.work.yaml diff --git a/.github/workflows/build_and_test.yml b/.github/workflows/build_and_test.yml index fc0959c5a415a..4a3707404bccf 100644 --- a/.github/workflows/build_and_test.yml +++ b/.github/workflows/build_and_test.yml @@ -671,8 +671,12 @@ jobs: run: | python3.11 -m pip install 'black==23.9.1' 'protobuf==5.28.3' 'mypy==1.8.0' 'mypy-protobuf==3.3.0' python3.11 -m pip list - - name: Python CodeGen check + - name: Python CodeGen check for branch-3.5 + if: inputs.branch == 'branch-3.5' run: ./dev/connect-check-protos.py + - name: Python CodeGen check + if: inputs.branch != 'branch-3.5' + run: ./dev/check-protos.py # Static analysis lint: diff --git a/dev/connect-check-protos.py b/dev/check-protos.py similarity index 73% rename from dev/connect-check-protos.py rename to dev/check-protos.py index 9ba56bae6b19c..bfca8b27be21c 100755 --- a/dev/connect-check-protos.py +++ b/dev/check-protos.py @@ -18,7 +18,7 @@ # # Utility for checking whether generated codes in PySpark are out of sync. -# usage: ./dev/connect-check-protos.py +# usage: ./dev/check-protos.py import os import sys @@ -43,12 +43,12 @@ def run_cmd(cmd): return subprocess.check_output(cmd.split(" ")).decode("utf-8") -def check_connect_protos(): - print("Start checking the generated codes in pyspark-connect.") - with tempfile.TemporaryDirectory(prefix="check_connect_protos") as tmp: - run_cmd(f"{SPARK_HOME}/dev/connect-gen-protos.sh {tmp}") +def check_protos(module_name, cmp_path, proto_path): + print(f"Start checking the generated codes in pyspark-${module_name}.") + with tempfile.TemporaryDirectory(prefix=f"check_${module_name}__protos") as tmp: + run_cmd(f"{SPARK_HOME}/dev/gen-protos.sh {module_name} {tmp}") result = filecmp.dircmp( - f"{SPARK_HOME}/python/pyspark/sql/connect/proto/", + f"{SPARK_HOME}/{cmp_path}", tmp, ignore=["__init__.py", "__pycache__"], ) @@ -71,14 +71,17 @@ def check_connect_protos(): success = False if success: - print("Finish checking the generated codes in pyspark-connect: SUCCESS") + print(f"Finish checking the generated codes in pyspark-${module_name}: SUCCESS") else: fail( "Generated files for pyspark-connect are out of sync! " - "If you have touched files under sql/connect/common/src/main/protobuf/, " - "please run ./dev/connect-gen-protos.sh. " + f"If you have touched files under ${proto_path}, " + f"please run ./dev/${module_name}-gen-protos.sh. " "If you haven't touched any file above, please rebase your PR against main branch." 
) -check_connect_protos() +check_protos( + "connect", "python/pyspark/sql/connect/proto/", "sql/connect/common/src/main/protobuf/" +) +check_protos("streaming", "python/pyspark/sql/streaming/proto/", "sql/core/src/main/protobuf/") diff --git a/dev/connect-gen-protos.sh b/dev/connect-gen-protos.sh index 2805908890eec..8ed323cc42599 100755 --- a/dev/connect-gen-protos.sh +++ b/dev/connect-gen-protos.sh @@ -24,80 +24,4 @@ if [[ $# -gt 1 ]]; then exit -1 fi - -SPARK_HOME="$(cd "`dirname $0`"/..; pwd)" -cd "$SPARK_HOME" - - -OUTPUT_PATH=${SPARK_HOME}/python/pyspark/sql/connect/proto/ -if [[ $# -eq 1 ]]; then - rm -Rf $1 - mkdir -p $1 - OUTPUT_PATH=$1 -fi - -pushd sql/connect/common/src/main - -LICENSE=$(cat <<'EOF' -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -EOF) -echo "$LICENSE" > /tmp/tmp_licence - - -# Delete the old generated protobuf files. -rm -Rf gen - -# Now, regenerate the new files -buf generate --debug -vvv - -# We need to edit the generate python files to account for the actual package location and not -# the one generated by proto. -for f in `find gen/proto/python -name "*.py*"`; do - # First fix the imports. - if [[ $f == *_pb2.py || $f == *_pb2_grpc.py ]]; then - sed -e 's/from spark.connect import/from pyspark.sql.connect.proto import/g' $f > $f.tmp - mv $f.tmp $f - # Now fix the module name in the serialized descriptor. - sed -e "s/DESCRIPTOR, 'spark.connect/DESCRIPTOR, 'pyspark.sql.connect.proto/g" $f > $f.tmp - mv $f.tmp $f - elif [[ $f == *.pyi ]]; then - sed -e 's/import spark.connect./import pyspark.sql.connect.proto./g' -e 's/spark.connect./pyspark.sql.connect.proto./g' -e '/ *@typing_extensions\.final/d' $f > $f.tmp - mv $f.tmp $f - fi - - # Prepend the Apache licence header to the files. - cp $f $f.bak - cat /tmp/tmp_licence $f.bak > $f - - LC=$(wc -l < $f) - echo $LC - if [[ $f == *_grpc.py && $LC -eq 20 ]]; then - rm $f - fi - rm $f.bak -done - -black --config $SPARK_HOME/dev/pyproject.toml gen/proto/python - -# Last step copy the result files to the destination module. -for f in `find gen/proto/python -name "*.py*"`; do - cp $f $OUTPUT_PATH -done - -# Clean up everything. -rm -Rf gen +./dev/gen-protos.sh connect "$@" diff --git a/dev/gen-protos.sh b/dev/gen-protos.sh new file mode 100755 index 0000000000000..d169964feb853 --- /dev/null +++ b/dev/gen-protos.sh @@ -0,0 +1,127 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +set -ex + +SPARK_HOME="$(cd "`dirname $0`"/..; pwd)" +cd "$SPARK_HOME" + +OUTPUT_PATH="" +MODULE="" +SOURCE_MODULE="" +TARGET_MODULE="" + +function usage() { + echo "Illegal number of parameters." + echo "Usage:./dev/gen-protos.sh [connect|streaming] [output_path]" + exit -1 +} + +if [[ $# -lt 1 || $# -gt 2 ]]; then + usage +fi + +if [[ $1 == "connect" ]]; then + MODULE="connect" + OUTPUT_PATH=${SPARK_HOME}/python/pyspark/sql/connect/proto/ + SOURCE_MODULE="spark.connect" + TARGET_MODULE="pyspark.sql.connect.proto" +elif [[ $1 == "streaming" ]]; then + MODULE="streaming" + OUTPUT_PATH=${SPARK_HOME}/python/pyspark/sql/streaming/proto/ + SOURCE_MODULE="org.apache.spark.sql.execution.streaming" + TARGET_MODULE="pyspark.sql.streaming.proto" +else + usage +fi + +if [[ $# -eq 2 ]]; then + rm -Rf $2 + mkdir -p $2 + OUTPUT_PATH=$2 +fi + +if [[ $MODULE == "connect" ]]; then + pushd sql/connect/common/src/main +elif [[ $MODULE == "streaming" ]]; then + pushd sql/core/src/main +fi + +LICENSE=$(cat <<'EOF' +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +EOF) +echo "$LICENSE" > /tmp/tmp_licence + +# Delete the old generated protobuf files. +rm -Rf gen + +# Now, regenerate the new files +buf generate --debug -vvv + +# We need to edit the generate python files to account for the actual package location and not +# the one generated by proto. +for f in `find gen/proto/python -name "*.py*"`; do + # First fix the imports. + if [[ $f == *_pb2.py || $f == *_pb2_grpc.py ]]; then + sed -e "s/from ${SOURCE_MODULE} import/from ${TARGET_MODULE} import/g" $f > $f.tmp + mv $f.tmp $f + # Now fix the module name in the serialized descriptor. + sed -e "s/DESCRIPTOR, '${SOURCE_MODULE}/DESCRIPTOR, '${TARGET_MODULE}/g" $f > $f.tmp + mv $f.tmp $f + elif [[ $f == *.pyi ]]; then + sed -e "s/import ${SOURCE_MODULE}./import ${TARGET_MODULE}./g" -e "s/${SOURCE_MODULE}./${TARGET_MODULE}./g" -e '/ *@typing_extensions\.final/d' $f > $f.tmp + mv $f.tmp $f + fi + + # Prepend the Apache licence header to the files. + cp $f $f.bak + cat /tmp/tmp_licence $f.bak > $f + + LC=$(wc -l < $f) + echo $LC + if [[ $f == *_grpc.py && $LC -eq 20 ]]; then + rm $f + fi + rm $f.bak +done + +black --config $SPARK_HOME/dev/pyproject.toml gen/proto/python + +# Last step copy the result files to the destination module. 
+for f in `find gen/proto/python -name "*.py*"`; do + cp $f $OUTPUT_PATH +done + +# Clean up everything. +rm -Rf gen diff --git a/dev/streaming-gen-protos.sh b/dev/streaming-gen-protos.sh new file mode 100755 index 0000000000000..3d80bda4fb94e --- /dev/null +++ b/dev/streaming-gen-protos.sh @@ -0,0 +1,27 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +set -ex + +if [[ $# -gt 1 ]]; then + echo "Illegal number of parameters." + echo "Usage: ./dev/streaming-gen-protos.sh [path]" + exit -1 +fi + +./dev/gen-protos.sh streaming "$@" diff --git a/dev/tox.ini b/dev/tox.ini index 47b1b4a9d7832..05a6b16a03bd9 100644 --- a/dev/tox.ini +++ b/dev/tox.ini @@ -59,5 +59,6 @@ exclude = *python/pyspark/worker.pyi, *python/pyspark/java_gateway.pyi, *python/pyspark/sql/connect/proto/*, + *python/pyspark/sql/streaming/proto/*, */venv/* max-line-length = 100 diff --git a/python/pyspark/sql/streaming/proto/StateMessage_pb2.py b/python/pyspark/sql/streaming/proto/StateMessage_pb2.py index 46bed10c45588..0a54690513a39 100644 --- a/python/pyspark/sql/streaming/proto/StateMessage_pb2.py +++ b/python/pyspark/sql/streaming/proto/StateMessage_pb2.py @@ -17,8 +17,8 @@ # -*- coding: utf-8 -*- # Generated by the protocol buffer compiler. DO NOT EDIT! 
# NO CHECKED-IN PROTOBUF GENCODE -# source: StateMessage.proto -# Protobuf Python Version: 5.27.3 +# source: org/apache/spark/sql/execution/streaming/StateMessage.proto +# Protobuf Python Version: 5.28.3 """Generated protocol buffer code.""" from google.protobuf import descriptor as _descriptor from google.protobuf import descriptor_pool as _descriptor_pool @@ -27,7 +27,12 @@ from google.protobuf.internal import builder as _builder _runtime_version.ValidateProtobufRuntimeVersion( - _runtime_version.Domain.PUBLIC, 5, 27, 3, "", "StateMessage.proto" + _runtime_version.Domain.PUBLIC, + 5, + 28, + 3, + "", + "org/apache/spark/sql/execution/streaming/StateMessage.proto", ) # @@protoc_insertion_point(imports) @@ -35,90 +40,92 @@ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\x12StateMessage.proto\x12.org.apache.spark.sql.execution.streaming.state"\xbf\x03\n\x0cStateRequest\x12\x0f\n\x07version\x18\x01 \x01(\x05\x12\x66\n\x15statefulProcessorCall\x18\x02 \x01(\x0b\x32\x45.org.apache.spark.sql.execution.streaming.state.StatefulProcessorCallH\x00\x12\x64\n\x14stateVariableRequest\x18\x03 \x01(\x0b\x32\x44.org.apache.spark.sql.execution.streaming.state.StateVariableRequestH\x00\x12p\n\x1aimplicitGroupingKeyRequest\x18\x04 \x01(\x0b\x32J.org.apache.spark.sql.execution.streaming.state.ImplicitGroupingKeyRequestH\x00\x12T\n\x0ctimerRequest\x18\x05 \x01(\x0b\x32<.org.apache.spark.sql.execution.streaming.state.TimerRequestH\x00\x42\x08\n\x06method"H\n\rStateResponse\x12\x12\n\nstatusCode\x18\x01 \x01(\x05\x12\x14\n\x0c\x65rrorMessage\x18\x02 \x01(\t\x12\r\n\x05value\x18\x03 \x01(\x0c"W\n\x1cStateResponseWithLongTypeVal\x12\x12\n\nstatusCode\x18\x01 \x01(\x05\x12\x14\n\x0c\x65rrorMessage\x18\x02 \x01(\t\x12\r\n\x05value\x18\x03 \x01(\x03"\xc6\x04\n\x15StatefulProcessorCall\x12X\n\x0esetHandleState\x18\x01 \x01(\x0b\x32>.org.apache.spark.sql.execution.streaming.state.SetHandleStateH\x00\x12Y\n\rgetValueState\x18\x02 \x01(\x0b\x32@.org.apache.spark.sql.execution.streaming.state.StateCallCommandH\x00\x12X\n\x0cgetListState\x18\x03 \x01(\x0b\x32@.org.apache.spark.sql.execution.streaming.state.StateCallCommandH\x00\x12W\n\x0bgetMapState\x18\x04 \x01(\x0b\x32@.org.apache.spark.sql.execution.streaming.state.StateCallCommandH\x00\x12_\n\x0etimerStateCall\x18\x05 \x01(\x0b\x32\x45.org.apache.spark.sql.execution.streaming.state.TimerStateCallCommandH\x00\x12Z\n\x0e\x64\x65leteIfExists\x18\x06 \x01(\x0b\x32@.org.apache.spark.sql.execution.streaming.state.StateCallCommandH\x00\x42\x08\n\x06method"\xa8\x02\n\x14StateVariableRequest\x12X\n\x0evalueStateCall\x18\x01 \x01(\x0b\x32>.org.apache.spark.sql.execution.streaming.state.ValueStateCallH\x00\x12V\n\rlistStateCall\x18\x02 \x01(\x0b\x32=.org.apache.spark.sql.execution.streaming.state.ListStateCallH\x00\x12T\n\x0cmapStateCall\x18\x03 \x01(\x0b\x32<.org.apache.spark.sql.execution.streaming.state.MapStateCallH\x00\x42\x08\n\x06method"\xe0\x01\n\x1aImplicitGroupingKeyRequest\x12X\n\x0esetImplicitKey\x18\x01 \x01(\x0b\x32>.org.apache.spark.sql.execution.streaming.state.SetImplicitKeyH\x00\x12^\n\x11removeImplicitKey\x18\x02 \x01(\x0b\x32\x41.org.apache.spark.sql.execution.streaming.state.RemoveImplicitKeyH\x00\x42\x08\n\x06method"\xda\x01\n\x0cTimerRequest\x12^\n\x11timerValueRequest\x18\x01 \x01(\x0b\x32\x41.org.apache.spark.sql.execution.streaming.state.TimerValueRequestH\x00\x12`\n\x12\x65xpiryTimerRequest\x18\x02 
\x01(\x0b\x32\x42.org.apache.spark.sql.execution.streaming.state.ExpiryTimerRequestH\x00\x42\x08\n\x06method"\xd4\x01\n\x11TimerValueRequest\x12_\n\x12getProcessingTimer\x18\x01 \x01(\x0b\x32\x41.org.apache.spark.sql.execution.streaming.state.GetProcessingTimeH\x00\x12T\n\x0cgetWatermark\x18\x02 \x01(\x0b\x32<.org.apache.spark.sql.execution.streaming.state.GetWatermarkH\x00\x42\x08\n\x06method"/\n\x12\x45xpiryTimerRequest\x12\x19\n\x11\x65xpiryTimestampMs\x18\x01 \x01(\x03"\x13\n\x11GetProcessingTime"\x0e\n\x0cGetWatermark"\x9a\x01\n\x10StateCallCommand\x12\x11\n\tstateName\x18\x01 \x01(\t\x12\x0e\n\x06schema\x18\x02 \x01(\t\x12\x1b\n\x13mapStateValueSchema\x18\x03 \x01(\t\x12\x46\n\x03ttl\x18\x04 \x01(\x0b\x32\x39.org.apache.spark.sql.execution.streaming.state.TTLConfig"\x8f\x02\n\x15TimerStateCallCommand\x12Q\n\x08register\x18\x01 \x01(\x0b\x32=.org.apache.spark.sql.execution.streaming.state.RegisterTimerH\x00\x12M\n\x06\x64\x65lete\x18\x02 \x01(\x0b\x32;.org.apache.spark.sql.execution.streaming.state.DeleteTimerH\x00\x12J\n\x04list\x18\x03 \x01(\x0b\x32:.org.apache.spark.sql.execution.streaming.state.ListTimersH\x00\x42\x08\n\x06method"\xe1\x02\n\x0eValueStateCall\x12\x11\n\tstateName\x18\x01 \x01(\t\x12H\n\x06\x65xists\x18\x02 \x01(\x0b\x32\x36.org.apache.spark.sql.execution.streaming.state.ExistsH\x00\x12\x42\n\x03get\x18\x03 \x01(\x0b\x32\x33.org.apache.spark.sql.execution.streaming.state.GetH\x00\x12\\\n\x10valueStateUpdate\x18\x04 \x01(\x0b\x32@.org.apache.spark.sql.execution.streaming.state.ValueStateUpdateH\x00\x12\x46\n\x05\x63lear\x18\x05 \x01(\x0b\x32\x35.org.apache.spark.sql.execution.streaming.state.ClearH\x00\x42\x08\n\x06method"\x90\x04\n\rListStateCall\x12\x11\n\tstateName\x18\x01 \x01(\t\x12H\n\x06\x65xists\x18\x02 \x01(\x0b\x32\x36.org.apache.spark.sql.execution.streaming.state.ExistsH\x00\x12T\n\x0clistStateGet\x18\x03 \x01(\x0b\x32<.org.apache.spark.sql.execution.streaming.state.ListStateGetH\x00\x12T\n\x0clistStatePut\x18\x04 \x01(\x0b\x32<.org.apache.spark.sql.execution.streaming.state.ListStatePutH\x00\x12R\n\x0b\x61ppendValue\x18\x05 \x01(\x0b\x32;.org.apache.spark.sql.execution.streaming.state.AppendValueH\x00\x12P\n\nappendList\x18\x06 \x01(\x0b\x32:.org.apache.spark.sql.execution.streaming.state.AppendListH\x00\x12\x46\n\x05\x63lear\x18\x07 \x01(\x0b\x32\x35.org.apache.spark.sql.execution.streaming.state.ClearH\x00\x42\x08\n\x06method"\xe1\x05\n\x0cMapStateCall\x12\x11\n\tstateName\x18\x01 \x01(\t\x12H\n\x06\x65xists\x18\x02 \x01(\x0b\x32\x36.org.apache.spark.sql.execution.streaming.state.ExistsH\x00\x12L\n\x08getValue\x18\x03 \x01(\x0b\x32\x38.org.apache.spark.sql.execution.streaming.state.GetValueH\x00\x12R\n\x0b\x63ontainsKey\x18\x04 \x01(\x0b\x32;.org.apache.spark.sql.execution.streaming.state.ContainsKeyH\x00\x12R\n\x0bupdateValue\x18\x05 \x01(\x0b\x32;.org.apache.spark.sql.execution.streaming.state.UpdateValueH\x00\x12L\n\x08iterator\x18\x06 \x01(\x0b\x32\x38.org.apache.spark.sql.execution.streaming.state.IteratorH\x00\x12\x44\n\x04keys\x18\x07 \x01(\x0b\x32\x34.org.apache.spark.sql.execution.streaming.state.KeysH\x00\x12H\n\x06values\x18\x08 \x01(\x0b\x32\x36.org.apache.spark.sql.execution.streaming.state.ValuesH\x00\x12N\n\tremoveKey\x18\t \x01(\x0b\x32\x39.org.apache.spark.sql.execution.streaming.state.RemoveKeyH\x00\x12\x46\n\x05\x63lear\x18\n \x01(\x0b\x32\x35.org.apache.spark.sql.execution.streaming.state.ClearH\x00\x42\x08\n\x06method"\x1d\n\x0eSetImplicitKey\x12\x0b\n\x03key\x18\x01 
\x01(\x0c"\x13\n\x11RemoveImplicitKey"\x08\n\x06\x45xists"\x05\n\x03Get"*\n\rRegisterTimer\x12\x19\n\x11\x65xpiryTimestampMs\x18\x01 \x01(\x03"(\n\x0b\x44\x65leteTimer\x12\x19\n\x11\x65xpiryTimestampMs\x18\x01 \x01(\x03" \n\nListTimers\x12\x12\n\niteratorId\x18\x01 \x01(\t"!\n\x10ValueStateUpdate\x12\r\n\x05value\x18\x01 \x01(\x0c"\x07\n\x05\x43lear""\n\x0cListStateGet\x12\x12\n\niteratorId\x18\x01 \x01(\t"\x0e\n\x0cListStatePut"\x1c\n\x0b\x41ppendValue\x12\r\n\x05value\x18\x01 \x01(\x0c"\x0c\n\nAppendList"\x1b\n\x08GetValue\x12\x0f\n\x07userKey\x18\x01 \x01(\x0c"\x1e\n\x0b\x43ontainsKey\x12\x0f\n\x07userKey\x18\x01 \x01(\x0c"-\n\x0bUpdateValue\x12\x0f\n\x07userKey\x18\x01 \x01(\x0c\x12\r\n\x05value\x18\x02 \x01(\x0c"\x1e\n\x08Iterator\x12\x12\n\niteratorId\x18\x01 \x01(\t"\x1a\n\x04Keys\x12\x12\n\niteratorId\x18\x01 \x01(\t"\x1c\n\x06Values\x12\x12\n\niteratorId\x18\x01 \x01(\t"\x1c\n\tRemoveKey\x12\x0f\n\x07userKey\x18\x01 \x01(\x0c"\\\n\x0eSetHandleState\x12J\n\x05state\x18\x01 \x01(\x0e\x32;.org.apache.spark.sql.execution.streaming.state.HandleState"\x1f\n\tTTLConfig\x12\x12\n\ndurationMs\x18\x01 \x01(\x05*`\n\x0bHandleState\x12\x0b\n\x07\x43REATED\x10\x00\x12\x0f\n\x0bINITIALIZED\x10\x01\x12\x12\n\x0e\x44\x41TA_PROCESSED\x10\x02\x12\x13\n\x0fTIMER_PROCESSED\x10\x03\x12\n\n\x06\x43LOSED\x10\x04\x62\x06proto3' # noqa: E501 + b'\n;org/apache/spark/sql/execution/streaming/StateMessage.proto\x12.org.apache.spark.sql.execution.streaming.state"\xa0\x04\n\x0cStateRequest\x12\x18\n\x07version\x18\x01 \x01(\x05R\x07version\x12}\n\x15statefulProcessorCall\x18\x02 \x01(\x0b\x32\x45.org.apache.spark.sql.execution.streaming.state.StatefulProcessorCallH\x00R\x15statefulProcessorCall\x12z\n\x14stateVariableRequest\x18\x03 \x01(\x0b\x32\x44.org.apache.spark.sql.execution.streaming.state.StateVariableRequestH\x00R\x14stateVariableRequest\x12\x8c\x01\n\x1aimplicitGroupingKeyRequest\x18\x04 \x01(\x0b\x32J.org.apache.spark.sql.execution.streaming.state.ImplicitGroupingKeyRequestH\x00R\x1aimplicitGroupingKeyRequest\x12\x62\n\x0ctimerRequest\x18\x05 \x01(\x0b\x32<.org.apache.spark.sql.execution.streaming.state.TimerRequestH\x00R\x0ctimerRequestB\x08\n\x06method"i\n\rStateResponse\x12\x1e\n\nstatusCode\x18\x01 \x01(\x05R\nstatusCode\x12"\n\x0c\x65rrorMessage\x18\x02 \x01(\tR\x0c\x65rrorMessage\x12\x14\n\x05value\x18\x03 \x01(\x0cR\x05value"x\n\x1cStateResponseWithLongTypeVal\x12\x1e\n\nstatusCode\x18\x01 \x01(\x05R\nstatusCode\x12"\n\x0c\x65rrorMessage\x18\x02 \x01(\tR\x0c\x65rrorMessage\x12\x14\n\x05value\x18\x03 \x01(\x03R\x05value"\xa0\x05\n\x15StatefulProcessorCall\x12h\n\x0esetHandleState\x18\x01 \x01(\x0b\x32>.org.apache.spark.sql.execution.streaming.state.SetHandleStateH\x00R\x0esetHandleState\x12h\n\rgetValueState\x18\x02 \x01(\x0b\x32@.org.apache.spark.sql.execution.streaming.state.StateCallCommandH\x00R\rgetValueState\x12\x66\n\x0cgetListState\x18\x03 \x01(\x0b\x32@.org.apache.spark.sql.execution.streaming.state.StateCallCommandH\x00R\x0cgetListState\x12\x64\n\x0bgetMapState\x18\x04 \x01(\x0b\x32@.org.apache.spark.sql.execution.streaming.state.StateCallCommandH\x00R\x0bgetMapState\x12o\n\x0etimerStateCall\x18\x05 \x01(\x0b\x32\x45.org.apache.spark.sql.execution.streaming.state.TimerStateCallCommandH\x00R\x0etimerStateCall\x12j\n\x0e\x64\x65leteIfExists\x18\x06 \x01(\x0b\x32@.org.apache.spark.sql.execution.streaming.state.StateCallCommandH\x00R\x0e\x64\x65leteIfExistsB\x08\n\x06method"\xd5\x02\n\x14StateVariableRequest\x12h\n\x0evalueStateCall\x18\x01 
\x01(\x0b\x32>.org.apache.spark.sql.execution.streaming.state.ValueStateCallH\x00R\x0evalueStateCall\x12\x65\n\rlistStateCall\x18\x02 \x01(\x0b\x32=.org.apache.spark.sql.execution.streaming.state.ListStateCallH\x00R\rlistStateCall\x12\x62\n\x0cmapStateCall\x18\x03 \x01(\x0b\x32<.org.apache.spark.sql.execution.streaming.state.MapStateCallH\x00R\x0cmapStateCallB\x08\n\x06method"\x83\x02\n\x1aImplicitGroupingKeyRequest\x12h\n\x0esetImplicitKey\x18\x01 \x01(\x0b\x32>.org.apache.spark.sql.execution.streaming.state.SetImplicitKeyH\x00R\x0esetImplicitKey\x12q\n\x11removeImplicitKey\x18\x02 \x01(\x0b\x32\x41.org.apache.spark.sql.execution.streaming.state.RemoveImplicitKeyH\x00R\x11removeImplicitKeyB\x08\n\x06method"\x81\x02\n\x0cTimerRequest\x12q\n\x11timerValueRequest\x18\x01 \x01(\x0b\x32\x41.org.apache.spark.sql.execution.streaming.state.TimerValueRequestH\x00R\x11timerValueRequest\x12t\n\x12\x65xpiryTimerRequest\x18\x02 \x01(\x0b\x32\x42.org.apache.spark.sql.execution.streaming.state.ExpiryTimerRequestH\x00R\x12\x65xpiryTimerRequestB\x08\n\x06method"\xf6\x01\n\x11TimerValueRequest\x12s\n\x12getProcessingTimer\x18\x01 \x01(\x0b\x32\x41.org.apache.spark.sql.execution.streaming.state.GetProcessingTimeH\x00R\x12getProcessingTimer\x12\x62\n\x0cgetWatermark\x18\x02 \x01(\x0b\x32<.org.apache.spark.sql.execution.streaming.state.GetWatermarkH\x00R\x0cgetWatermarkB\x08\n\x06method"B\n\x12\x45xpiryTimerRequest\x12,\n\x11\x65xpiryTimestampMs\x18\x01 \x01(\x03R\x11\x65xpiryTimestampMs"\x13\n\x11GetProcessingTime"\x0e\n\x0cGetWatermark"\xc7\x01\n\x10StateCallCommand\x12\x1c\n\tstateName\x18\x01 \x01(\tR\tstateName\x12\x16\n\x06schema\x18\x02 \x01(\tR\x06schema\x12\x30\n\x13mapStateValueSchema\x18\x03 \x01(\tR\x13mapStateValueSchema\x12K\n\x03ttl\x18\x04 \x01(\x0b\x32\x39.org.apache.spark.sql.execution.streaming.state.TTLConfigR\x03ttl"\xa7\x02\n\x15TimerStateCallCommand\x12[\n\x08register\x18\x01 \x01(\x0b\x32=.org.apache.spark.sql.execution.streaming.state.RegisterTimerH\x00R\x08register\x12U\n\x06\x64\x65lete\x18\x02 \x01(\x0b\x32;.org.apache.spark.sql.execution.streaming.state.DeleteTimerH\x00R\x06\x64\x65lete\x12P\n\x04list\x18\x03 \x01(\x0b\x32:.org.apache.spark.sql.execution.streaming.state.ListTimersH\x00R\x04listB\x08\n\x06method"\x92\x03\n\x0eValueStateCall\x12\x1c\n\tstateName\x18\x01 \x01(\tR\tstateName\x12P\n\x06\x65xists\x18\x02 \x01(\x0b\x32\x36.org.apache.spark.sql.execution.streaming.state.ExistsH\x00R\x06\x65xists\x12G\n\x03get\x18\x03 \x01(\x0b\x32\x33.org.apache.spark.sql.execution.streaming.state.GetH\x00R\x03get\x12n\n\x10valueStateUpdate\x18\x04 \x01(\x0b\x32@.org.apache.spark.sql.execution.streaming.state.ValueStateUpdateH\x00R\x10valueStateUpdate\x12M\n\x05\x63lear\x18\x05 \x01(\x0b\x32\x35.org.apache.spark.sql.execution.streaming.state.ClearH\x00R\x05\x63learB\x08\n\x06method"\xdf\x04\n\rListStateCall\x12\x1c\n\tstateName\x18\x01 \x01(\tR\tstateName\x12P\n\x06\x65xists\x18\x02 \x01(\x0b\x32\x36.org.apache.spark.sql.execution.streaming.state.ExistsH\x00R\x06\x65xists\x12\x62\n\x0clistStateGet\x18\x03 \x01(\x0b\x32<.org.apache.spark.sql.execution.streaming.state.ListStateGetH\x00R\x0clistStateGet\x12\x62\n\x0clistStatePut\x18\x04 \x01(\x0b\x32<.org.apache.spark.sql.execution.streaming.state.ListStatePutH\x00R\x0clistStatePut\x12_\n\x0b\x61ppendValue\x18\x05 \x01(\x0b\x32;.org.apache.spark.sql.execution.streaming.state.AppendValueH\x00R\x0b\x61ppendValue\x12\\\n\nappendList\x18\x06 
\x01(\x0b\x32:.org.apache.spark.sql.execution.streaming.state.AppendListH\x00R\nappendList\x12M\n\x05\x63lear\x18\x07 \x01(\x0b\x32\x35.org.apache.spark.sql.execution.streaming.state.ClearH\x00R\x05\x63learB\x08\n\x06method"\xc2\x06\n\x0cMapStateCall\x12\x1c\n\tstateName\x18\x01 \x01(\tR\tstateName\x12P\n\x06\x65xists\x18\x02 \x01(\x0b\x32\x36.org.apache.spark.sql.execution.streaming.state.ExistsH\x00R\x06\x65xists\x12V\n\x08getValue\x18\x03 \x01(\x0b\x32\x38.org.apache.spark.sql.execution.streaming.state.GetValueH\x00R\x08getValue\x12_\n\x0b\x63ontainsKey\x18\x04 \x01(\x0b\x32;.org.apache.spark.sql.execution.streaming.state.ContainsKeyH\x00R\x0b\x63ontainsKey\x12_\n\x0bupdateValue\x18\x05 \x01(\x0b\x32;.org.apache.spark.sql.execution.streaming.state.UpdateValueH\x00R\x0bupdateValue\x12V\n\x08iterator\x18\x06 \x01(\x0b\x32\x38.org.apache.spark.sql.execution.streaming.state.IteratorH\x00R\x08iterator\x12J\n\x04keys\x18\x07 \x01(\x0b\x32\x34.org.apache.spark.sql.execution.streaming.state.KeysH\x00R\x04keys\x12P\n\x06values\x18\x08 \x01(\x0b\x32\x36.org.apache.spark.sql.execution.streaming.state.ValuesH\x00R\x06values\x12Y\n\tremoveKey\x18\t \x01(\x0b\x32\x39.org.apache.spark.sql.execution.streaming.state.RemoveKeyH\x00R\tremoveKey\x12M\n\x05\x63lear\x18\n \x01(\x0b\x32\x35.org.apache.spark.sql.execution.streaming.state.ClearH\x00R\x05\x63learB\x08\n\x06method""\n\x0eSetImplicitKey\x12\x10\n\x03key\x18\x01 \x01(\x0cR\x03key"\x13\n\x11RemoveImplicitKey"\x08\n\x06\x45xists"\x05\n\x03Get"=\n\rRegisterTimer\x12,\n\x11\x65xpiryTimestampMs\x18\x01 \x01(\x03R\x11\x65xpiryTimestampMs";\n\x0b\x44\x65leteTimer\x12,\n\x11\x65xpiryTimestampMs\x18\x01 \x01(\x03R\x11\x65xpiryTimestampMs",\n\nListTimers\x12\x1e\n\niteratorId\x18\x01 \x01(\tR\niteratorId"(\n\x10ValueStateUpdate\x12\x14\n\x05value\x18\x01 \x01(\x0cR\x05value"\x07\n\x05\x43lear".\n\x0cListStateGet\x12\x1e\n\niteratorId\x18\x01 \x01(\tR\niteratorId"\x0e\n\x0cListStatePut"#\n\x0b\x41ppendValue\x12\x14\n\x05value\x18\x01 \x01(\x0cR\x05value"\x0c\n\nAppendList"$\n\x08GetValue\x12\x18\n\x07userKey\x18\x01 \x01(\x0cR\x07userKey"\'\n\x0b\x43ontainsKey\x12\x18\n\x07userKey\x18\x01 \x01(\x0cR\x07userKey"=\n\x0bUpdateValue\x12\x18\n\x07userKey\x18\x01 \x01(\x0cR\x07userKey\x12\x14\n\x05value\x18\x02 \x01(\x0cR\x05value"*\n\x08Iterator\x12\x1e\n\niteratorId\x18\x01 \x01(\tR\niteratorId"&\n\x04Keys\x12\x1e\n\niteratorId\x18\x01 \x01(\tR\niteratorId"(\n\x06Values\x12\x1e\n\niteratorId\x18\x01 \x01(\tR\niteratorId"%\n\tRemoveKey\x12\x18\n\x07userKey\x18\x01 \x01(\x0cR\x07userKey"c\n\x0eSetHandleState\x12Q\n\x05state\x18\x01 \x01(\x0e\x32;.org.apache.spark.sql.execution.streaming.state.HandleStateR\x05state"+\n\tTTLConfig\x12\x1e\n\ndurationMs\x18\x01 \x01(\x05R\ndurationMs*`\n\x0bHandleState\x12\x0b\n\x07\x43REATED\x10\x00\x12\x0f\n\x0bINITIALIZED\x10\x01\x12\x12\n\x0e\x44\x41TA_PROCESSED\x10\x02\x12\x13\n\x0fTIMER_PROCESSED\x10\x03\x12\n\n\x06\x43LOSED\x10\x04\x62\x06proto3' ) _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) -_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "StateMessage_pb2", _globals) +_builder.BuildTopDescriptorsAndMessages( + DESCRIPTOR, "pyspark.sql.streaming.proto.StateMessage_pb2", _globals +) if not _descriptor._USE_C_DESCRIPTORS: DESCRIPTOR._loaded_options = None - _globals["_HANDLESTATE"]._serialized_start = 5058 - _globals["_HANDLESTATE"]._serialized_end = 5154 - _globals["_STATEREQUEST"]._serialized_start = 71 - _globals["_STATEREQUEST"]._serialized_end = 518 - 
_globals["_STATERESPONSE"]._serialized_start = 520 - _globals["_STATERESPONSE"]._serialized_end = 592 - _globals["_STATERESPONSEWITHLONGTYPEVAL"]._serialized_start = 594 - _globals["_STATERESPONSEWITHLONGTYPEVAL"]._serialized_end = 681 - _globals["_STATEFULPROCESSORCALL"]._serialized_start = 684 - _globals["_STATEFULPROCESSORCALL"]._serialized_end = 1266 - _globals["_STATEVARIABLEREQUEST"]._serialized_start = 1269 - _globals["_STATEVARIABLEREQUEST"]._serialized_end = 1565 - _globals["_IMPLICITGROUPINGKEYREQUEST"]._serialized_start = 1568 - _globals["_IMPLICITGROUPINGKEYREQUEST"]._serialized_end = 1792 - _globals["_TIMERREQUEST"]._serialized_start = 1795 - _globals["_TIMERREQUEST"]._serialized_end = 2013 - _globals["_TIMERVALUEREQUEST"]._serialized_start = 2016 - _globals["_TIMERVALUEREQUEST"]._serialized_end = 2228 - _globals["_EXPIRYTIMERREQUEST"]._serialized_start = 2230 - _globals["_EXPIRYTIMERREQUEST"]._serialized_end = 2277 - _globals["_GETPROCESSINGTIME"]._serialized_start = 2279 - _globals["_GETPROCESSINGTIME"]._serialized_end = 2298 - _globals["_GETWATERMARK"]._serialized_start = 2300 - _globals["_GETWATERMARK"]._serialized_end = 2314 - _globals["_STATECALLCOMMAND"]._serialized_start = 2317 - _globals["_STATECALLCOMMAND"]._serialized_end = 2471 - _globals["_TIMERSTATECALLCOMMAND"]._serialized_start = 2474 - _globals["_TIMERSTATECALLCOMMAND"]._serialized_end = 2745 - _globals["_VALUESTATECALL"]._serialized_start = 2748 - _globals["_VALUESTATECALL"]._serialized_end = 3101 - _globals["_LISTSTATECALL"]._serialized_start = 3104 - _globals["_LISTSTATECALL"]._serialized_end = 3632 - _globals["_MAPSTATECALL"]._serialized_start = 3635 - _globals["_MAPSTATECALL"]._serialized_end = 4372 - _globals["_SETIMPLICITKEY"]._serialized_start = 4374 - _globals["_SETIMPLICITKEY"]._serialized_end = 4403 - _globals["_REMOVEIMPLICITKEY"]._serialized_start = 4405 - _globals["_REMOVEIMPLICITKEY"]._serialized_end = 4424 - _globals["_EXISTS"]._serialized_start = 4426 - _globals["_EXISTS"]._serialized_end = 4434 - _globals["_GET"]._serialized_start = 4436 - _globals["_GET"]._serialized_end = 4441 - _globals["_REGISTERTIMER"]._serialized_start = 4443 - _globals["_REGISTERTIMER"]._serialized_end = 4485 - _globals["_DELETETIMER"]._serialized_start = 4487 - _globals["_DELETETIMER"]._serialized_end = 4527 - _globals["_LISTTIMERS"]._serialized_start = 4529 - _globals["_LISTTIMERS"]._serialized_end = 4561 - _globals["_VALUESTATEUPDATE"]._serialized_start = 4563 - _globals["_VALUESTATEUPDATE"]._serialized_end = 4596 - _globals["_CLEAR"]._serialized_start = 4598 - _globals["_CLEAR"]._serialized_end = 4605 - _globals["_LISTSTATEGET"]._serialized_start = 4607 - _globals["_LISTSTATEGET"]._serialized_end = 4641 - _globals["_LISTSTATEPUT"]._serialized_start = 4643 - _globals["_LISTSTATEPUT"]._serialized_end = 4657 - _globals["_APPENDVALUE"]._serialized_start = 4659 - _globals["_APPENDVALUE"]._serialized_end = 4687 - _globals["_APPENDLIST"]._serialized_start = 4689 - _globals["_APPENDLIST"]._serialized_end = 4701 - _globals["_GETVALUE"]._serialized_start = 4703 - _globals["_GETVALUE"]._serialized_end = 4730 - _globals["_CONTAINSKEY"]._serialized_start = 4732 - _globals["_CONTAINSKEY"]._serialized_end = 4762 - _globals["_UPDATEVALUE"]._serialized_start = 4764 - _globals["_UPDATEVALUE"]._serialized_end = 4809 - _globals["_ITERATOR"]._serialized_start = 4811 - _globals["_ITERATOR"]._serialized_end = 4841 - _globals["_KEYS"]._serialized_start = 4843 - _globals["_KEYS"]._serialized_end = 4869 - 
_globals["_VALUES"]._serialized_start = 4871 - _globals["_VALUES"]._serialized_end = 4899 - _globals["_REMOVEKEY"]._serialized_start = 4901 - _globals["_REMOVEKEY"]._serialized_end = 4929 - _globals["_SETHANDLESTATE"]._serialized_start = 4931 - _globals["_SETHANDLESTATE"]._serialized_end = 5023 - _globals["_TTLCONFIG"]._serialized_start = 5025 - _globals["_TTLCONFIG"]._serialized_end = 5056 + _globals["_HANDLESTATE"]._serialized_start = 5997 + _globals["_HANDLESTATE"]._serialized_end = 6093 + _globals["_STATEREQUEST"]._serialized_start = 112 + _globals["_STATEREQUEST"]._serialized_end = 656 + _globals["_STATERESPONSE"]._serialized_start = 658 + _globals["_STATERESPONSE"]._serialized_end = 763 + _globals["_STATERESPONSEWITHLONGTYPEVAL"]._serialized_start = 765 + _globals["_STATERESPONSEWITHLONGTYPEVAL"]._serialized_end = 885 + _globals["_STATEFULPROCESSORCALL"]._serialized_start = 888 + _globals["_STATEFULPROCESSORCALL"]._serialized_end = 1560 + _globals["_STATEVARIABLEREQUEST"]._serialized_start = 1563 + _globals["_STATEVARIABLEREQUEST"]._serialized_end = 1904 + _globals["_IMPLICITGROUPINGKEYREQUEST"]._serialized_start = 1907 + _globals["_IMPLICITGROUPINGKEYREQUEST"]._serialized_end = 2166 + _globals["_TIMERREQUEST"]._serialized_start = 2169 + _globals["_TIMERREQUEST"]._serialized_end = 2426 + _globals["_TIMERVALUEREQUEST"]._serialized_start = 2429 + _globals["_TIMERVALUEREQUEST"]._serialized_end = 2675 + _globals["_EXPIRYTIMERREQUEST"]._serialized_start = 2677 + _globals["_EXPIRYTIMERREQUEST"]._serialized_end = 2743 + _globals["_GETPROCESSINGTIME"]._serialized_start = 2745 + _globals["_GETPROCESSINGTIME"]._serialized_end = 2764 + _globals["_GETWATERMARK"]._serialized_start = 2766 + _globals["_GETWATERMARK"]._serialized_end = 2780 + _globals["_STATECALLCOMMAND"]._serialized_start = 2783 + _globals["_STATECALLCOMMAND"]._serialized_end = 2982 + _globals["_TIMERSTATECALLCOMMAND"]._serialized_start = 2985 + _globals["_TIMERSTATECALLCOMMAND"]._serialized_end = 3280 + _globals["_VALUESTATECALL"]._serialized_start = 3283 + _globals["_VALUESTATECALL"]._serialized_end = 3685 + _globals["_LISTSTATECALL"]._serialized_start = 3688 + _globals["_LISTSTATECALL"]._serialized_end = 4295 + _globals["_MAPSTATECALL"]._serialized_start = 4298 + _globals["_MAPSTATECALL"]._serialized_end = 5132 + _globals["_SETIMPLICITKEY"]._serialized_start = 5134 + _globals["_SETIMPLICITKEY"]._serialized_end = 5168 + _globals["_REMOVEIMPLICITKEY"]._serialized_start = 5170 + _globals["_REMOVEIMPLICITKEY"]._serialized_end = 5189 + _globals["_EXISTS"]._serialized_start = 5191 + _globals["_EXISTS"]._serialized_end = 5199 + _globals["_GET"]._serialized_start = 5201 + _globals["_GET"]._serialized_end = 5206 + _globals["_REGISTERTIMER"]._serialized_start = 5208 + _globals["_REGISTERTIMER"]._serialized_end = 5269 + _globals["_DELETETIMER"]._serialized_start = 5271 + _globals["_DELETETIMER"]._serialized_end = 5330 + _globals["_LISTTIMERS"]._serialized_start = 5332 + _globals["_LISTTIMERS"]._serialized_end = 5376 + _globals["_VALUESTATEUPDATE"]._serialized_start = 5378 + _globals["_VALUESTATEUPDATE"]._serialized_end = 5418 + _globals["_CLEAR"]._serialized_start = 5420 + _globals["_CLEAR"]._serialized_end = 5427 + _globals["_LISTSTATEGET"]._serialized_start = 5429 + _globals["_LISTSTATEGET"]._serialized_end = 5475 + _globals["_LISTSTATEPUT"]._serialized_start = 5477 + _globals["_LISTSTATEPUT"]._serialized_end = 5491 + _globals["_APPENDVALUE"]._serialized_start = 5493 + _globals["_APPENDVALUE"]._serialized_end = 5528 + 
_globals["_APPENDLIST"]._serialized_start = 5530 + _globals["_APPENDLIST"]._serialized_end = 5542 + _globals["_GETVALUE"]._serialized_start = 5544 + _globals["_GETVALUE"]._serialized_end = 5580 + _globals["_CONTAINSKEY"]._serialized_start = 5582 + _globals["_CONTAINSKEY"]._serialized_end = 5621 + _globals["_UPDATEVALUE"]._serialized_start = 5623 + _globals["_UPDATEVALUE"]._serialized_end = 5684 + _globals["_ITERATOR"]._serialized_start = 5686 + _globals["_ITERATOR"]._serialized_end = 5728 + _globals["_KEYS"]._serialized_start = 5730 + _globals["_KEYS"]._serialized_end = 5768 + _globals["_VALUES"]._serialized_start = 5770 + _globals["_VALUES"]._serialized_end = 5810 + _globals["_REMOVEKEY"]._serialized_start = 5812 + _globals["_REMOVEKEY"]._serialized_end = 5849 + _globals["_SETHANDLESTATE"]._serialized_start = 5851 + _globals["_SETHANDLESTATE"]._serialized_end = 5950 + _globals["_TTLCONFIG"]._serialized_start = 5952 + _globals["_TTLCONFIG"]._serialized_end = 5995 # @@protoc_insertion_point(module_scope) diff --git a/python/pyspark/sql/streaming/proto/StateMessage_pb2.pyi b/python/pyspark/sql/streaming/proto/StateMessage_pb2.pyi index bc5138f52281c..52f66928294cb 100644 --- a/python/pyspark/sql/streaming/proto/StateMessage_pb2.pyi +++ b/python/pyspark/sql/streaming/proto/StateMessage_pb2.pyi @@ -14,439 +14,1119 @@ # See the License for the specific language governing permissions and # limitations under the License. # -from google.protobuf.internal import enum_type_wrapper as _enum_type_wrapper -from google.protobuf import descriptor as _descriptor -from google.protobuf import message as _message -from typing import ( - ClassVar as _ClassVar, - Mapping as _Mapping, - Optional as _Optional, - Union as _Union, -) - -DESCRIPTOR: _descriptor.FileDescriptor - -class HandleState(int, metaclass=_enum_type_wrapper.EnumTypeWrapper): - __slots__ = () - CREATED: _ClassVar[HandleState] - INITIALIZED: _ClassVar[HandleState] - DATA_PROCESSED: _ClassVar[HandleState] - TIMER_PROCESSED: _ClassVar[HandleState] - CLOSED: _ClassVar[HandleState] - -CREATED: HandleState -INITIALIZED: HandleState -DATA_PROCESSED: HandleState -TIMER_PROCESSED: HandleState -CLOSED: HandleState - -class StateRequest(_message.Message): - __slots__ = ( - "version", - "statefulProcessorCall", - "stateVariableRequest", - "implicitGroupingKeyRequest", - "timerRequest", - ) - VERSION_FIELD_NUMBER: _ClassVar[int] - STATEFULPROCESSORCALL_FIELD_NUMBER: _ClassVar[int] - STATEVARIABLEREQUEST_FIELD_NUMBER: _ClassVar[int] - IMPLICITGROUPINGKEYREQUEST_FIELD_NUMBER: _ClassVar[int] - TIMERREQUEST_FIELD_NUMBER: _ClassVar[int] - version: int - statefulProcessorCall: StatefulProcessorCall - stateVariableRequest: StateVariableRequest - implicitGroupingKeyRequest: ImplicitGroupingKeyRequest - timerRequest: TimerRequest - def __init__( - self, - version: _Optional[int] = ..., - statefulProcessorCall: _Optional[_Union[StatefulProcessorCall, _Mapping]] = ..., - stateVariableRequest: _Optional[_Union[StateVariableRequest, _Mapping]] = ..., - implicitGroupingKeyRequest: _Optional[_Union[ImplicitGroupingKeyRequest, _Mapping]] = ..., - timerRequest: _Optional[_Union[TimerRequest, _Mapping]] = ..., - ) -> None: ... 
- -class StateResponse(_message.Message): - __slots__ = ("statusCode", "errorMessage", "value") - STATUSCODE_FIELD_NUMBER: _ClassVar[int] - ERRORMESSAGE_FIELD_NUMBER: _ClassVar[int] - VALUE_FIELD_NUMBER: _ClassVar[int] - statusCode: int - errorMessage: str - value: bytes - def __init__( - self, - statusCode: _Optional[int] = ..., - errorMessage: _Optional[str] = ..., - value: _Optional[bytes] = ..., - ) -> None: ... - -class StateResponseWithLongTypeVal(_message.Message): - __slots__ = ("statusCode", "errorMessage", "value") - STATUSCODE_FIELD_NUMBER: _ClassVar[int] - ERRORMESSAGE_FIELD_NUMBER: _ClassVar[int] - VALUE_FIELD_NUMBER: _ClassVar[int] - statusCode: int - errorMessage: str - value: int - def __init__( - self, - statusCode: _Optional[int] = ..., - errorMessage: _Optional[str] = ..., - value: _Optional[int] = ..., - ) -> None: ... - -class StatefulProcessorCall(_message.Message): - __slots__ = ( - "setHandleState", - "getValueState", - "getListState", - "getMapState", - "timerStateCall", - "deleteIfExists", - ) - SETHANDLESTATE_FIELD_NUMBER: _ClassVar[int] - GETVALUESTATE_FIELD_NUMBER: _ClassVar[int] - GETLISTSTATE_FIELD_NUMBER: _ClassVar[int] - GETMAPSTATE_FIELD_NUMBER: _ClassVar[int] - TIMERSTATECALL_FIELD_NUMBER: _ClassVar[int] - DELETEIFEXISTS_FIELD_NUMBER: _ClassVar[int] - setHandleState: SetHandleState - getValueState: StateCallCommand - getListState: StateCallCommand - getMapState: StateCallCommand - timerStateCall: TimerStateCallCommand - deleteIfExists: StateCallCommand - def __init__( - self, - setHandleState: _Optional[_Union[SetHandleState, _Mapping]] = ..., - getValueState: _Optional[_Union[StateCallCommand, _Mapping]] = ..., - getListState: _Optional[_Union[StateCallCommand, _Mapping]] = ..., - getMapState: _Optional[_Union[StateCallCommand, _Mapping]] = ..., - timerStateCall: _Optional[_Union[TimerStateCallCommand, _Mapping]] = ..., - deleteIfExists: _Optional[_Union[StateCallCommand, _Mapping]] = ..., - ) -> None: ... - -class StateVariableRequest(_message.Message): - __slots__ = ("valueStateCall", "listStateCall", "mapStateCall") - VALUESTATECALL_FIELD_NUMBER: _ClassVar[int] - LISTSTATECALL_FIELD_NUMBER: _ClassVar[int] - MAPSTATECALL_FIELD_NUMBER: _ClassVar[int] - valueStateCall: ValueStateCall - listStateCall: ListStateCall - mapStateCall: MapStateCall - def __init__( - self, - valueStateCall: _Optional[_Union[ValueStateCall, _Mapping]] = ..., - listStateCall: _Optional[_Union[ListStateCall, _Mapping]] = ..., - mapStateCall: _Optional[_Union[MapStateCall, _Mapping]] = ..., - ) -> None: ... - -class ImplicitGroupingKeyRequest(_message.Message): - __slots__ = ("setImplicitKey", "removeImplicitKey") - SETIMPLICITKEY_FIELD_NUMBER: _ClassVar[int] - REMOVEIMPLICITKEY_FIELD_NUMBER: _ClassVar[int] - setImplicitKey: SetImplicitKey - removeImplicitKey: RemoveImplicitKey - def __init__( - self, - setImplicitKey: _Optional[_Union[SetImplicitKey, _Mapping]] = ..., - removeImplicitKey: _Optional[_Union[RemoveImplicitKey, _Mapping]] = ..., - ) -> None: ... - -class TimerRequest(_message.Message): - __slots__ = ("timerValueRequest", "expiryTimerRequest") - TIMERVALUEREQUEST_FIELD_NUMBER: _ClassVar[int] - EXPIRYTIMERREQUEST_FIELD_NUMBER: _ClassVar[int] - timerValueRequest: TimerValueRequest - expiryTimerRequest: ExpiryTimerRequest - def __init__( - self, - timerValueRequest: _Optional[_Union[TimerValueRequest, _Mapping]] = ..., - expiryTimerRequest: _Optional[_Union[ExpiryTimerRequest, _Mapping]] = ..., - ) -> None: ... 
- -class TimerValueRequest(_message.Message): - __slots__ = ("getProcessingTimer", "getWatermark") - GETPROCESSINGTIMER_FIELD_NUMBER: _ClassVar[int] - GETWATERMARK_FIELD_NUMBER: _ClassVar[int] - getProcessingTimer: GetProcessingTime - getWatermark: GetWatermark - def __init__( - self, - getProcessingTimer: _Optional[_Union[GetProcessingTime, _Mapping]] = ..., - getWatermark: _Optional[_Union[GetWatermark, _Mapping]] = ..., - ) -> None: ... - -class ExpiryTimerRequest(_message.Message): - __slots__ = ("expiryTimestampMs",) - EXPIRYTIMESTAMPMS_FIELD_NUMBER: _ClassVar[int] - expiryTimestampMs: int - def __init__(self, expiryTimestampMs: _Optional[int] = ...) -> None: ... - -class GetProcessingTime(_message.Message): - __slots__ = () - def __init__(self) -> None: ... - -class GetWatermark(_message.Message): - __slots__ = () - def __init__(self) -> None: ... - -class StateCallCommand(_message.Message): - __slots__ = ("stateName", "schema", "mapStateValueSchema", "ttl") - STATENAME_FIELD_NUMBER: _ClassVar[int] - SCHEMA_FIELD_NUMBER: _ClassVar[int] - MAPSTATEVALUESCHEMA_FIELD_NUMBER: _ClassVar[int] - TTL_FIELD_NUMBER: _ClassVar[int] - stateName: str - schema: str - mapStateValueSchema: str - ttl: TTLConfig - def __init__( - self, - stateName: _Optional[str] = ..., - schema: _Optional[str] = ..., - mapStateValueSchema: _Optional[str] = ..., - ttl: _Optional[_Union[TTLConfig, _Mapping]] = ..., - ) -> None: ... - -class TimerStateCallCommand(_message.Message): - __slots__ = ("register", "delete", "list") - REGISTER_FIELD_NUMBER: _ClassVar[int] - DELETE_FIELD_NUMBER: _ClassVar[int] - LIST_FIELD_NUMBER: _ClassVar[int] - register: RegisterTimer - delete: DeleteTimer - list: ListTimers - def __init__( - self, - register: _Optional[_Union[RegisterTimer, _Mapping]] = ..., - delete: _Optional[_Union[DeleteTimer, _Mapping]] = ..., - list: _Optional[_Union[ListTimers, _Mapping]] = ..., - ) -> None: ... - -class ValueStateCall(_message.Message): - __slots__ = ("stateName", "exists", "get", "valueStateUpdate", "clear") - STATENAME_FIELD_NUMBER: _ClassVar[int] - EXISTS_FIELD_NUMBER: _ClassVar[int] - GET_FIELD_NUMBER: _ClassVar[int] - VALUESTATEUPDATE_FIELD_NUMBER: _ClassVar[int] - CLEAR_FIELD_NUMBER: _ClassVar[int] - stateName: str - exists: Exists - get: Get - valueStateUpdate: ValueStateUpdate - clear: Clear - def __init__( - self, - stateName: _Optional[str] = ..., - exists: _Optional[_Union[Exists, _Mapping]] = ..., - get: _Optional[_Union[Get, _Mapping]] = ..., - valueStateUpdate: _Optional[_Union[ValueStateUpdate, _Mapping]] = ..., - clear: _Optional[_Union[Clear, _Mapping]] = ..., - ) -> None: ... 
- -class ListStateCall(_message.Message): - __slots__ = ( - "stateName", - "exists", - "listStateGet", - "listStatePut", - "appendValue", - "appendList", - "clear", - ) - STATENAME_FIELD_NUMBER: _ClassVar[int] - EXISTS_FIELD_NUMBER: _ClassVar[int] - LISTSTATEGET_FIELD_NUMBER: _ClassVar[int] - LISTSTATEPUT_FIELD_NUMBER: _ClassVar[int] - APPENDVALUE_FIELD_NUMBER: _ClassVar[int] - APPENDLIST_FIELD_NUMBER: _ClassVar[int] - CLEAR_FIELD_NUMBER: _ClassVar[int] - stateName: str - exists: Exists - listStateGet: ListStateGet - listStatePut: ListStatePut - appendValue: AppendValue - appendList: AppendList - clear: Clear - def __init__( - self, - stateName: _Optional[str] = ..., - exists: _Optional[_Union[Exists, _Mapping]] = ..., - listStateGet: _Optional[_Union[ListStateGet, _Mapping]] = ..., - listStatePut: _Optional[_Union[ListStatePut, _Mapping]] = ..., - appendValue: _Optional[_Union[AppendValue, _Mapping]] = ..., - appendList: _Optional[_Union[AppendList, _Mapping]] = ..., - clear: _Optional[_Union[Clear, _Mapping]] = ..., - ) -> None: ... - -class MapStateCall(_message.Message): - __slots__ = ( - "stateName", - "exists", - "getValue", - "containsKey", - "updateValue", - "iterator", - "keys", - "values", - "removeKey", - "clear", - ) - STATENAME_FIELD_NUMBER: _ClassVar[int] - EXISTS_FIELD_NUMBER: _ClassVar[int] - GETVALUE_FIELD_NUMBER: _ClassVar[int] - CONTAINSKEY_FIELD_NUMBER: _ClassVar[int] - UPDATEVALUE_FIELD_NUMBER: _ClassVar[int] - ITERATOR_FIELD_NUMBER: _ClassVar[int] - KEYS_FIELD_NUMBER: _ClassVar[int] - VALUES_FIELD_NUMBER: _ClassVar[int] - REMOVEKEY_FIELD_NUMBER: _ClassVar[int] - CLEAR_FIELD_NUMBER: _ClassVar[int] - stateName: str - exists: Exists - getValue: GetValue - containsKey: ContainsKey - updateValue: UpdateValue - iterator: Iterator - keys: Keys - values: Values - removeKey: RemoveKey - clear: Clear - def __init__( - self, - stateName: _Optional[str] = ..., - exists: _Optional[_Union[Exists, _Mapping]] = ..., - getValue: _Optional[_Union[GetValue, _Mapping]] = ..., - containsKey: _Optional[_Union[ContainsKey, _Mapping]] = ..., - updateValue: _Optional[_Union[UpdateValue, _Mapping]] = ..., - iterator: _Optional[_Union[Iterator, _Mapping]] = ..., - keys: _Optional[_Union[Keys, _Mapping]] = ..., - values: _Optional[_Union[Values, _Mapping]] = ..., - removeKey: _Optional[_Union[RemoveKey, _Mapping]] = ..., - clear: _Optional[_Union[Clear, _Mapping]] = ..., - ) -> None: ... - -class SetImplicitKey(_message.Message): - __slots__ = ("key",) - KEY_FIELD_NUMBER: _ClassVar[int] - key: bytes - def __init__(self, key: _Optional[bytes] = ...) -> None: ... - -class RemoveImplicitKey(_message.Message): - __slots__ = () - def __init__(self) -> None: ... - -class Exists(_message.Message): - __slots__ = () - def __init__(self) -> None: ... - -class Get(_message.Message): - __slots__ = () - def __init__(self) -> None: ... - -class RegisterTimer(_message.Message): - __slots__ = ("expiryTimestampMs",) - EXPIRYTIMESTAMPMS_FIELD_NUMBER: _ClassVar[int] - expiryTimestampMs: int - def __init__(self, expiryTimestampMs: _Optional[int] = ...) -> None: ... - -class DeleteTimer(_message.Message): - __slots__ = ("expiryTimestampMs",) - EXPIRYTIMESTAMPMS_FIELD_NUMBER: _ClassVar[int] - expiryTimestampMs: int - def __init__(self, expiryTimestampMs: _Optional[int] = ...) -> None: ... - -class ListTimers(_message.Message): - __slots__ = ("iteratorId",) - ITERATORID_FIELD_NUMBER: _ClassVar[int] - iteratorId: str - def __init__(self, iteratorId: _Optional[str] = ...) -> None: ... 
- -class ValueStateUpdate(_message.Message): - __slots__ = ("value",) - VALUE_FIELD_NUMBER: _ClassVar[int] - value: bytes - def __init__(self, value: _Optional[bytes] = ...) -> None: ... - -class Clear(_message.Message): - __slots__ = () - def __init__(self) -> None: ... - -class ListStateGet(_message.Message): - __slots__ = ("iteratorId",) - ITERATORID_FIELD_NUMBER: _ClassVar[int] - iteratorId: str - def __init__(self, iteratorId: _Optional[str] = ...) -> None: ... - -class ListStatePut(_message.Message): - __slots__ = () - def __init__(self) -> None: ... - -class AppendValue(_message.Message): - __slots__ = ("value",) - VALUE_FIELD_NUMBER: _ClassVar[int] - value: bytes - def __init__(self, value: _Optional[bytes] = ...) -> None: ... - -class AppendList(_message.Message): - __slots__ = () - def __init__(self) -> None: ... - -class GetValue(_message.Message): - __slots__ = ("userKey",) - USERKEY_FIELD_NUMBER: _ClassVar[int] - userKey: bytes - def __init__(self, userKey: _Optional[bytes] = ...) -> None: ... - -class ContainsKey(_message.Message): - __slots__ = ("userKey",) - USERKEY_FIELD_NUMBER: _ClassVar[int] - userKey: bytes - def __init__(self, userKey: _Optional[bytes] = ...) -> None: ... - -class UpdateValue(_message.Message): - __slots__ = ("userKey", "value") - USERKEY_FIELD_NUMBER: _ClassVar[int] - VALUE_FIELD_NUMBER: _ClassVar[int] - userKey: bytes - value: bytes - def __init__(self, userKey: _Optional[bytes] = ..., value: _Optional[bytes] = ...) -> None: ... - -class Iterator(_message.Message): - __slots__ = ("iteratorId",) - ITERATORID_FIELD_NUMBER: _ClassVar[int] - iteratorId: str - def __init__(self, iteratorId: _Optional[str] = ...) -> None: ... - -class Keys(_message.Message): - __slots__ = ("iteratorId",) - ITERATORID_FIELD_NUMBER: _ClassVar[int] - iteratorId: str - def __init__(self, iteratorId: _Optional[str] = ...) -> None: ... - -class Values(_message.Message): - __slots__ = ("iteratorId",) - ITERATORID_FIELD_NUMBER: _ClassVar[int] - iteratorId: str - def __init__(self, iteratorId: _Optional[str] = ...) -> None: ... - -class RemoveKey(_message.Message): - __slots__ = ("userKey",) - USERKEY_FIELD_NUMBER: _ClassVar[int] - userKey: bytes - def __init__(self, userKey: _Optional[bytes] = ...) -> None: ... - -class SetHandleState(_message.Message): - __slots__ = ("state",) - STATE_FIELD_NUMBER: _ClassVar[int] - state: HandleState - def __init__(self, state: _Optional[_Union[HandleState, str]] = ...) -> None: ... - -class TTLConfig(_message.Message): - __slots__ = ("durationMs",) - DURATIONMS_FIELD_NUMBER: _ClassVar[int] - durationMs: int - def __init__(self, durationMs: _Optional[int] = ...) -> None: ... +""" +@generated by mypy-protobuf. Do not edit manually! +isort:skip_file + +Licensed to the Apache Software Foundation (ASF) under one or more +contributor license agreements. See the NOTICE file distributed with +this work for additional information regarding copyright ownership. +The ASF licenses this file to You under the Apache License, Version 2.0 +(the "License"); you may not use this file except in compliance with +the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" +import builtins +import google.protobuf.descriptor +import google.protobuf.internal.enum_type_wrapper +import google.protobuf.message +import sys +import typing + +if sys.version_info >= (3, 10): + import typing as typing_extensions +else: + import typing_extensions + +DESCRIPTOR: google.protobuf.descriptor.FileDescriptor + +class _HandleState: + ValueType = typing.NewType("ValueType", builtins.int) + V: typing_extensions.TypeAlias = ValueType + +class _HandleStateEnumTypeWrapper( + google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[_HandleState.ValueType], + builtins.type, +): # noqa: F821 + DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor + CREATED: _HandleState.ValueType # 0 + INITIALIZED: _HandleState.ValueType # 1 + DATA_PROCESSED: _HandleState.ValueType # 2 + TIMER_PROCESSED: _HandleState.ValueType # 3 + CLOSED: _HandleState.ValueType # 4 + +class HandleState(_HandleState, metaclass=_HandleStateEnumTypeWrapper): ... + +CREATED: HandleState.ValueType # 0 +INITIALIZED: HandleState.ValueType # 1 +DATA_PROCESSED: HandleState.ValueType # 2 +TIMER_PROCESSED: HandleState.ValueType # 3 +CLOSED: HandleState.ValueType # 4 +global___HandleState = HandleState + +class StateRequest(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + VERSION_FIELD_NUMBER: builtins.int + STATEFULPROCESSORCALL_FIELD_NUMBER: builtins.int + STATEVARIABLEREQUEST_FIELD_NUMBER: builtins.int + IMPLICITGROUPINGKEYREQUEST_FIELD_NUMBER: builtins.int + TIMERREQUEST_FIELD_NUMBER: builtins.int + version: builtins.int + @property + def statefulProcessorCall(self) -> global___StatefulProcessorCall: ... + @property + def stateVariableRequest(self) -> global___StateVariableRequest: ... + @property + def implicitGroupingKeyRequest(self) -> global___ImplicitGroupingKeyRequest: ... + @property + def timerRequest(self) -> global___TimerRequest: ... + def __init__( + self, + *, + version: builtins.int = ..., + statefulProcessorCall: global___StatefulProcessorCall | None = ..., + stateVariableRequest: global___StateVariableRequest | None = ..., + implicitGroupingKeyRequest: global___ImplicitGroupingKeyRequest | None = ..., + timerRequest: global___TimerRequest | None = ..., + ) -> None: ... + def HasField( + self, + field_name: typing_extensions.Literal[ + "implicitGroupingKeyRequest", + b"implicitGroupingKeyRequest", + "method", + b"method", + "stateVariableRequest", + b"stateVariableRequest", + "statefulProcessorCall", + b"statefulProcessorCall", + "timerRequest", + b"timerRequest", + ], + ) -> builtins.bool: ... + def ClearField( + self, + field_name: typing_extensions.Literal[ + "implicitGroupingKeyRequest", + b"implicitGroupingKeyRequest", + "method", + b"method", + "stateVariableRequest", + b"stateVariableRequest", + "statefulProcessorCall", + b"statefulProcessorCall", + "timerRequest", + b"timerRequest", + "version", + b"version", + ], + ) -> None: ... + def WhichOneof( + self, oneof_group: typing_extensions.Literal["method", b"method"] + ) -> ( + typing_extensions.Literal[ + "statefulProcessorCall", + "stateVariableRequest", + "implicitGroupingKeyRequest", + "timerRequest", + ] + | None + ): ... 
+ +global___StateRequest = StateRequest + +class StateResponse(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + STATUSCODE_FIELD_NUMBER: builtins.int + ERRORMESSAGE_FIELD_NUMBER: builtins.int + VALUE_FIELD_NUMBER: builtins.int + statusCode: builtins.int + errorMessage: builtins.str + value: builtins.bytes + def __init__( + self, + *, + statusCode: builtins.int = ..., + errorMessage: builtins.str = ..., + value: builtins.bytes = ..., + ) -> None: ... + def ClearField( + self, + field_name: typing_extensions.Literal[ + "errorMessage", b"errorMessage", "statusCode", b"statusCode", "value", b"value" + ], + ) -> None: ... + +global___StateResponse = StateResponse + +class StateResponseWithLongTypeVal(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + STATUSCODE_FIELD_NUMBER: builtins.int + ERRORMESSAGE_FIELD_NUMBER: builtins.int + VALUE_FIELD_NUMBER: builtins.int + statusCode: builtins.int + errorMessage: builtins.str + value: builtins.int + def __init__( + self, + *, + statusCode: builtins.int = ..., + errorMessage: builtins.str = ..., + value: builtins.int = ..., + ) -> None: ... + def ClearField( + self, + field_name: typing_extensions.Literal[ + "errorMessage", b"errorMessage", "statusCode", b"statusCode", "value", b"value" + ], + ) -> None: ... + +global___StateResponseWithLongTypeVal = StateResponseWithLongTypeVal + +class StatefulProcessorCall(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + SETHANDLESTATE_FIELD_NUMBER: builtins.int + GETVALUESTATE_FIELD_NUMBER: builtins.int + GETLISTSTATE_FIELD_NUMBER: builtins.int + GETMAPSTATE_FIELD_NUMBER: builtins.int + TIMERSTATECALL_FIELD_NUMBER: builtins.int + DELETEIFEXISTS_FIELD_NUMBER: builtins.int + @property + def setHandleState(self) -> global___SetHandleState: ... + @property + def getValueState(self) -> global___StateCallCommand: ... + @property + def getListState(self) -> global___StateCallCommand: ... + @property + def getMapState(self) -> global___StateCallCommand: ... + @property + def timerStateCall(self) -> global___TimerStateCallCommand: ... + @property + def deleteIfExists(self) -> global___StateCallCommand: ... + def __init__( + self, + *, + setHandleState: global___SetHandleState | None = ..., + getValueState: global___StateCallCommand | None = ..., + getListState: global___StateCallCommand | None = ..., + getMapState: global___StateCallCommand | None = ..., + timerStateCall: global___TimerStateCallCommand | None = ..., + deleteIfExists: global___StateCallCommand | None = ..., + ) -> None: ... + def HasField( + self, + field_name: typing_extensions.Literal[ + "deleteIfExists", + b"deleteIfExists", + "getListState", + b"getListState", + "getMapState", + b"getMapState", + "getValueState", + b"getValueState", + "method", + b"method", + "setHandleState", + b"setHandleState", + "timerStateCall", + b"timerStateCall", + ], + ) -> builtins.bool: ... + def ClearField( + self, + field_name: typing_extensions.Literal[ + "deleteIfExists", + b"deleteIfExists", + "getListState", + b"getListState", + "getMapState", + b"getMapState", + "getValueState", + b"getValueState", + "method", + b"method", + "setHandleState", + b"setHandleState", + "timerStateCall", + b"timerStateCall", + ], + ) -> None: ... 
+ def WhichOneof( + self, oneof_group: typing_extensions.Literal["method", b"method"] + ) -> ( + typing_extensions.Literal[ + "setHandleState", + "getValueState", + "getListState", + "getMapState", + "timerStateCall", + "deleteIfExists", + ] + | None + ): ... + +global___StatefulProcessorCall = StatefulProcessorCall + +class StateVariableRequest(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + VALUESTATECALL_FIELD_NUMBER: builtins.int + LISTSTATECALL_FIELD_NUMBER: builtins.int + MAPSTATECALL_FIELD_NUMBER: builtins.int + @property + def valueStateCall(self) -> global___ValueStateCall: ... + @property + def listStateCall(self) -> global___ListStateCall: ... + @property + def mapStateCall(self) -> global___MapStateCall: ... + def __init__( + self, + *, + valueStateCall: global___ValueStateCall | None = ..., + listStateCall: global___ListStateCall | None = ..., + mapStateCall: global___MapStateCall | None = ..., + ) -> None: ... + def HasField( + self, + field_name: typing_extensions.Literal[ + "listStateCall", + b"listStateCall", + "mapStateCall", + b"mapStateCall", + "method", + b"method", + "valueStateCall", + b"valueStateCall", + ], + ) -> builtins.bool: ... + def ClearField( + self, + field_name: typing_extensions.Literal[ + "listStateCall", + b"listStateCall", + "mapStateCall", + b"mapStateCall", + "method", + b"method", + "valueStateCall", + b"valueStateCall", + ], + ) -> None: ... + def WhichOneof( + self, oneof_group: typing_extensions.Literal["method", b"method"] + ) -> typing_extensions.Literal["valueStateCall", "listStateCall", "mapStateCall"] | None: ... + +global___StateVariableRequest = StateVariableRequest + +class ImplicitGroupingKeyRequest(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + SETIMPLICITKEY_FIELD_NUMBER: builtins.int + REMOVEIMPLICITKEY_FIELD_NUMBER: builtins.int + @property + def setImplicitKey(self) -> global___SetImplicitKey: ... + @property + def removeImplicitKey(self) -> global___RemoveImplicitKey: ... + def __init__( + self, + *, + setImplicitKey: global___SetImplicitKey | None = ..., + removeImplicitKey: global___RemoveImplicitKey | None = ..., + ) -> None: ... + def HasField( + self, + field_name: typing_extensions.Literal[ + "method", + b"method", + "removeImplicitKey", + b"removeImplicitKey", + "setImplicitKey", + b"setImplicitKey", + ], + ) -> builtins.bool: ... + def ClearField( + self, + field_name: typing_extensions.Literal[ + "method", + b"method", + "removeImplicitKey", + b"removeImplicitKey", + "setImplicitKey", + b"setImplicitKey", + ], + ) -> None: ... + def WhichOneof( + self, oneof_group: typing_extensions.Literal["method", b"method"] + ) -> typing_extensions.Literal["setImplicitKey", "removeImplicitKey"] | None: ... + +global___ImplicitGroupingKeyRequest = ImplicitGroupingKeyRequest + +class TimerRequest(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + TIMERVALUEREQUEST_FIELD_NUMBER: builtins.int + EXPIRYTIMERREQUEST_FIELD_NUMBER: builtins.int + @property + def timerValueRequest(self) -> global___TimerValueRequest: ... + @property + def expiryTimerRequest(self) -> global___ExpiryTimerRequest: ... + def __init__( + self, + *, + timerValueRequest: global___TimerValueRequest | None = ..., + expiryTimerRequest: global___ExpiryTimerRequest | None = ..., + ) -> None: ... 
+ def HasField( + self, + field_name: typing_extensions.Literal[ + "expiryTimerRequest", + b"expiryTimerRequest", + "method", + b"method", + "timerValueRequest", + b"timerValueRequest", + ], + ) -> builtins.bool: ... + def ClearField( + self, + field_name: typing_extensions.Literal[ + "expiryTimerRequest", + b"expiryTimerRequest", + "method", + b"method", + "timerValueRequest", + b"timerValueRequest", + ], + ) -> None: ... + def WhichOneof( + self, oneof_group: typing_extensions.Literal["method", b"method"] + ) -> typing_extensions.Literal["timerValueRequest", "expiryTimerRequest"] | None: ... + +global___TimerRequest = TimerRequest + +class TimerValueRequest(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + GETPROCESSINGTIMER_FIELD_NUMBER: builtins.int + GETWATERMARK_FIELD_NUMBER: builtins.int + @property + def getProcessingTimer(self) -> global___GetProcessingTime: ... + @property + def getWatermark(self) -> global___GetWatermark: ... + def __init__( + self, + *, + getProcessingTimer: global___GetProcessingTime | None = ..., + getWatermark: global___GetWatermark | None = ..., + ) -> None: ... + def HasField( + self, + field_name: typing_extensions.Literal[ + "getProcessingTimer", + b"getProcessingTimer", + "getWatermark", + b"getWatermark", + "method", + b"method", + ], + ) -> builtins.bool: ... + def ClearField( + self, + field_name: typing_extensions.Literal[ + "getProcessingTimer", + b"getProcessingTimer", + "getWatermark", + b"getWatermark", + "method", + b"method", + ], + ) -> None: ... + def WhichOneof( + self, oneof_group: typing_extensions.Literal["method", b"method"] + ) -> typing_extensions.Literal["getProcessingTimer", "getWatermark"] | None: ... + +global___TimerValueRequest = TimerValueRequest + +class ExpiryTimerRequest(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + EXPIRYTIMESTAMPMS_FIELD_NUMBER: builtins.int + expiryTimestampMs: builtins.int + def __init__( + self, + *, + expiryTimestampMs: builtins.int = ..., + ) -> None: ... + def ClearField( + self, field_name: typing_extensions.Literal["expiryTimestampMs", b"expiryTimestampMs"] + ) -> None: ... + +global___ExpiryTimerRequest = ExpiryTimerRequest + +class GetProcessingTime(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + def __init__( + self, + ) -> None: ... + +global___GetProcessingTime = GetProcessingTime + +class GetWatermark(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + def __init__( + self, + ) -> None: ... + +global___GetWatermark = GetWatermark + +class StateCallCommand(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + STATENAME_FIELD_NUMBER: builtins.int + SCHEMA_FIELD_NUMBER: builtins.int + MAPSTATEVALUESCHEMA_FIELD_NUMBER: builtins.int + TTL_FIELD_NUMBER: builtins.int + stateName: builtins.str + schema: builtins.str + mapStateValueSchema: builtins.str + @property + def ttl(self) -> global___TTLConfig: ... + def __init__( + self, + *, + stateName: builtins.str = ..., + schema: builtins.str = ..., + mapStateValueSchema: builtins.str = ..., + ttl: global___TTLConfig | None = ..., + ) -> None: ... + def HasField(self, field_name: typing_extensions.Literal["ttl", b"ttl"]) -> builtins.bool: ... + def ClearField( + self, + field_name: typing_extensions.Literal[ + "mapStateValueSchema", + b"mapStateValueSchema", + "schema", + b"schema", + "stateName", + b"stateName", + "ttl", + b"ttl", + ], + ) -> None: ... 
+ +global___StateCallCommand = StateCallCommand + +class TimerStateCallCommand(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + REGISTER_FIELD_NUMBER: builtins.int + DELETE_FIELD_NUMBER: builtins.int + LIST_FIELD_NUMBER: builtins.int + @property + def register(self) -> global___RegisterTimer: ... + @property + def delete(self) -> global___DeleteTimer: ... + @property + def list(self) -> global___ListTimers: ... + def __init__( + self, + *, + register: global___RegisterTimer | None = ..., + delete: global___DeleteTimer | None = ..., + list: global___ListTimers | None = ..., + ) -> None: ... + def HasField( + self, + field_name: typing_extensions.Literal[ + "delete", b"delete", "list", b"list", "method", b"method", "register", b"register" + ], + ) -> builtins.bool: ... + def ClearField( + self, + field_name: typing_extensions.Literal[ + "delete", b"delete", "list", b"list", "method", b"method", "register", b"register" + ], + ) -> None: ... + def WhichOneof( + self, oneof_group: typing_extensions.Literal["method", b"method"] + ) -> typing_extensions.Literal["register", "delete", "list"] | None: ... + +global___TimerStateCallCommand = TimerStateCallCommand + +class ValueStateCall(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + STATENAME_FIELD_NUMBER: builtins.int + EXISTS_FIELD_NUMBER: builtins.int + GET_FIELD_NUMBER: builtins.int + VALUESTATEUPDATE_FIELD_NUMBER: builtins.int + CLEAR_FIELD_NUMBER: builtins.int + stateName: builtins.str + @property + def exists(self) -> global___Exists: ... + @property + def get(self) -> global___Get: ... + @property + def valueStateUpdate(self) -> global___ValueStateUpdate: ... + @property + def clear(self) -> global___Clear: ... + def __init__( + self, + *, + stateName: builtins.str = ..., + exists: global___Exists | None = ..., + get: global___Get | None = ..., + valueStateUpdate: global___ValueStateUpdate | None = ..., + clear: global___Clear | None = ..., + ) -> None: ... + def HasField( + self, + field_name: typing_extensions.Literal[ + "clear", + b"clear", + "exists", + b"exists", + "get", + b"get", + "method", + b"method", + "valueStateUpdate", + b"valueStateUpdate", + ], + ) -> builtins.bool: ... + def ClearField( + self, + field_name: typing_extensions.Literal[ + "clear", + b"clear", + "exists", + b"exists", + "get", + b"get", + "method", + b"method", + "stateName", + b"stateName", + "valueStateUpdate", + b"valueStateUpdate", + ], + ) -> None: ... + def WhichOneof( + self, oneof_group: typing_extensions.Literal["method", b"method"] + ) -> typing_extensions.Literal["exists", "get", "valueStateUpdate", "clear"] | None: ... + +global___ValueStateCall = ValueStateCall + +class ListStateCall(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + STATENAME_FIELD_NUMBER: builtins.int + EXISTS_FIELD_NUMBER: builtins.int + LISTSTATEGET_FIELD_NUMBER: builtins.int + LISTSTATEPUT_FIELD_NUMBER: builtins.int + APPENDVALUE_FIELD_NUMBER: builtins.int + APPENDLIST_FIELD_NUMBER: builtins.int + CLEAR_FIELD_NUMBER: builtins.int + stateName: builtins.str + @property + def exists(self) -> global___Exists: ... + @property + def listStateGet(self) -> global___ListStateGet: ... + @property + def listStatePut(self) -> global___ListStatePut: ... + @property + def appendValue(self) -> global___AppendValue: ... + @property + def appendList(self) -> global___AppendList: ... + @property + def clear(self) -> global___Clear: ... 
+ def __init__( + self, + *, + stateName: builtins.str = ..., + exists: global___Exists | None = ..., + listStateGet: global___ListStateGet | None = ..., + listStatePut: global___ListStatePut | None = ..., + appendValue: global___AppendValue | None = ..., + appendList: global___AppendList | None = ..., + clear: global___Clear | None = ..., + ) -> None: ... + def HasField( + self, + field_name: typing_extensions.Literal[ + "appendList", + b"appendList", + "appendValue", + b"appendValue", + "clear", + b"clear", + "exists", + b"exists", + "listStateGet", + b"listStateGet", + "listStatePut", + b"listStatePut", + "method", + b"method", + ], + ) -> builtins.bool: ... + def ClearField( + self, + field_name: typing_extensions.Literal[ + "appendList", + b"appendList", + "appendValue", + b"appendValue", + "clear", + b"clear", + "exists", + b"exists", + "listStateGet", + b"listStateGet", + "listStatePut", + b"listStatePut", + "method", + b"method", + "stateName", + b"stateName", + ], + ) -> None: ... + def WhichOneof( + self, oneof_group: typing_extensions.Literal["method", b"method"] + ) -> ( + typing_extensions.Literal[ + "exists", "listStateGet", "listStatePut", "appendValue", "appendList", "clear" + ] + | None + ): ... + +global___ListStateCall = ListStateCall + +class MapStateCall(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + STATENAME_FIELD_NUMBER: builtins.int + EXISTS_FIELD_NUMBER: builtins.int + GETVALUE_FIELD_NUMBER: builtins.int + CONTAINSKEY_FIELD_NUMBER: builtins.int + UPDATEVALUE_FIELD_NUMBER: builtins.int + ITERATOR_FIELD_NUMBER: builtins.int + KEYS_FIELD_NUMBER: builtins.int + VALUES_FIELD_NUMBER: builtins.int + REMOVEKEY_FIELD_NUMBER: builtins.int + CLEAR_FIELD_NUMBER: builtins.int + stateName: builtins.str + @property + def exists(self) -> global___Exists: ... + @property + def getValue(self) -> global___GetValue: ... + @property + def containsKey(self) -> global___ContainsKey: ... + @property + def updateValue(self) -> global___UpdateValue: ... + @property + def iterator(self) -> global___Iterator: ... + @property + def keys(self) -> global___Keys: ... + @property + def values(self) -> global___Values: ... + @property + def removeKey(self) -> global___RemoveKey: ... + @property + def clear(self) -> global___Clear: ... + def __init__( + self, + *, + stateName: builtins.str = ..., + exists: global___Exists | None = ..., + getValue: global___GetValue | None = ..., + containsKey: global___ContainsKey | None = ..., + updateValue: global___UpdateValue | None = ..., + iterator: global___Iterator | None = ..., + keys: global___Keys | None = ..., + values: global___Values | None = ..., + removeKey: global___RemoveKey | None = ..., + clear: global___Clear | None = ..., + ) -> None: ... + def HasField( + self, + field_name: typing_extensions.Literal[ + "clear", + b"clear", + "containsKey", + b"containsKey", + "exists", + b"exists", + "getValue", + b"getValue", + "iterator", + b"iterator", + "keys", + b"keys", + "method", + b"method", + "removeKey", + b"removeKey", + "updateValue", + b"updateValue", + "values", + b"values", + ], + ) -> builtins.bool: ... + def ClearField( + self, + field_name: typing_extensions.Literal[ + "clear", + b"clear", + "containsKey", + b"containsKey", + "exists", + b"exists", + "getValue", + b"getValue", + "iterator", + b"iterator", + "keys", + b"keys", + "method", + b"method", + "removeKey", + b"removeKey", + "stateName", + b"stateName", + "updateValue", + b"updateValue", + "values", + b"values", + ], + ) -> None: ... 
+ def WhichOneof( + self, oneof_group: typing_extensions.Literal["method", b"method"] + ) -> ( + typing_extensions.Literal[ + "exists", + "getValue", + "containsKey", + "updateValue", + "iterator", + "keys", + "values", + "removeKey", + "clear", + ] + | None + ): ... + +global___MapStateCall = MapStateCall + +class SetImplicitKey(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + KEY_FIELD_NUMBER: builtins.int + key: builtins.bytes + def __init__( + self, + *, + key: builtins.bytes = ..., + ) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["key", b"key"]) -> None: ... + +global___SetImplicitKey = SetImplicitKey + +class RemoveImplicitKey(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + def __init__( + self, + ) -> None: ... + +global___RemoveImplicitKey = RemoveImplicitKey + +class Exists(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + def __init__( + self, + ) -> None: ... + +global___Exists = Exists + +class Get(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + def __init__( + self, + ) -> None: ... + +global___Get = Get + +class RegisterTimer(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + EXPIRYTIMESTAMPMS_FIELD_NUMBER: builtins.int + expiryTimestampMs: builtins.int + def __init__( + self, + *, + expiryTimestampMs: builtins.int = ..., + ) -> None: ... + def ClearField( + self, field_name: typing_extensions.Literal["expiryTimestampMs", b"expiryTimestampMs"] + ) -> None: ... + +global___RegisterTimer = RegisterTimer + +class DeleteTimer(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + EXPIRYTIMESTAMPMS_FIELD_NUMBER: builtins.int + expiryTimestampMs: builtins.int + def __init__( + self, + *, + expiryTimestampMs: builtins.int = ..., + ) -> None: ... + def ClearField( + self, field_name: typing_extensions.Literal["expiryTimestampMs", b"expiryTimestampMs"] + ) -> None: ... + +global___DeleteTimer = DeleteTimer + +class ListTimers(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + ITERATORID_FIELD_NUMBER: builtins.int + iteratorId: builtins.str + def __init__( + self, + *, + iteratorId: builtins.str = ..., + ) -> None: ... + def ClearField( + self, field_name: typing_extensions.Literal["iteratorId", b"iteratorId"] + ) -> None: ... + +global___ListTimers = ListTimers + +class ValueStateUpdate(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + VALUE_FIELD_NUMBER: builtins.int + value: builtins.bytes + def __init__( + self, + *, + value: builtins.bytes = ..., + ) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["value", b"value"]) -> None: ... + +global___ValueStateUpdate = ValueStateUpdate + +class Clear(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + def __init__( + self, + ) -> None: ... + +global___Clear = Clear + +class ListStateGet(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + ITERATORID_FIELD_NUMBER: builtins.int + iteratorId: builtins.str + def __init__( + self, + *, + iteratorId: builtins.str = ..., + ) -> None: ... + def ClearField( + self, field_name: typing_extensions.Literal["iteratorId", b"iteratorId"] + ) -> None: ... 
+ +global___ListStateGet = ListStateGet + +class ListStatePut(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + def __init__( + self, + ) -> None: ... + +global___ListStatePut = ListStatePut + +class AppendValue(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + VALUE_FIELD_NUMBER: builtins.int + value: builtins.bytes + def __init__( + self, + *, + value: builtins.bytes = ..., + ) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["value", b"value"]) -> None: ... + +global___AppendValue = AppendValue + +class AppendList(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + def __init__( + self, + ) -> None: ... + +global___AppendList = AppendList + +class GetValue(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + USERKEY_FIELD_NUMBER: builtins.int + userKey: builtins.bytes + def __init__( + self, + *, + userKey: builtins.bytes = ..., + ) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["userKey", b"userKey"]) -> None: ... + +global___GetValue = GetValue + +class ContainsKey(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + USERKEY_FIELD_NUMBER: builtins.int + userKey: builtins.bytes + def __init__( + self, + *, + userKey: builtins.bytes = ..., + ) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["userKey", b"userKey"]) -> None: ... + +global___ContainsKey = ContainsKey + +class UpdateValue(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + USERKEY_FIELD_NUMBER: builtins.int + VALUE_FIELD_NUMBER: builtins.int + userKey: builtins.bytes + value: builtins.bytes + def __init__( + self, + *, + userKey: builtins.bytes = ..., + value: builtins.bytes = ..., + ) -> None: ... + def ClearField( + self, field_name: typing_extensions.Literal["userKey", b"userKey", "value", b"value"] + ) -> None: ... + +global___UpdateValue = UpdateValue + +class Iterator(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + ITERATORID_FIELD_NUMBER: builtins.int + iteratorId: builtins.str + def __init__( + self, + *, + iteratorId: builtins.str = ..., + ) -> None: ... + def ClearField( + self, field_name: typing_extensions.Literal["iteratorId", b"iteratorId"] + ) -> None: ... + +global___Iterator = Iterator + +class Keys(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + ITERATORID_FIELD_NUMBER: builtins.int + iteratorId: builtins.str + def __init__( + self, + *, + iteratorId: builtins.str = ..., + ) -> None: ... + def ClearField( + self, field_name: typing_extensions.Literal["iteratorId", b"iteratorId"] + ) -> None: ... + +global___Keys = Keys + +class Values(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + ITERATORID_FIELD_NUMBER: builtins.int + iteratorId: builtins.str + def __init__( + self, + *, + iteratorId: builtins.str = ..., + ) -> None: ... + def ClearField( + self, field_name: typing_extensions.Literal["iteratorId", b"iteratorId"] + ) -> None: ... + +global___Values = Values + +class RemoveKey(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + USERKEY_FIELD_NUMBER: builtins.int + userKey: builtins.bytes + def __init__( + self, + *, + userKey: builtins.bytes = ..., + ) -> None: ... 
+ def ClearField(self, field_name: typing_extensions.Literal["userKey", b"userKey"]) -> None: ... + +global___RemoveKey = RemoveKey + +class SetHandleState(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + STATE_FIELD_NUMBER: builtins.int + state: global___HandleState.ValueType + def __init__( + self, + *, + state: global___HandleState.ValueType = ..., + ) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["state", b"state"]) -> None: ... + +global___SetHandleState = SetHandleState + +class TTLConfig(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + DURATIONMS_FIELD_NUMBER: builtins.int + durationMs: builtins.int + def __init__( + self, + *, + durationMs: builtins.int = ..., + ) -> None: ... + def ClearField( + self, field_name: typing_extensions.Literal["durationMs", b"durationMs"] + ) -> None: ... + +global___TTLConfig = TTLConfig diff --git a/sql/core/src/main/buf.gen.yaml b/sql/core/src/main/buf.gen.yaml new file mode 100644 index 0000000000000..94da50c2c41c8 --- /dev/null +++ b/sql/core/src/main/buf.gen.yaml @@ -0,0 +1,24 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +version: v1 +plugins: + # Building the Python build and building the mypy interfaces. + - plugin: buf.build/protocolbuffers/python:v28.3 + out: gen/proto/python + - name: mypy + out: gen/proto/python + diff --git a/sql/core/src/main/buf.work.yaml b/sql/core/src/main/buf.work.yaml new file mode 100644 index 0000000000000..a02dead420cdf --- /dev/null +++ b/sql/core/src/main/buf.work.yaml @@ -0,0 +1,19 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +version: v1 +directories: + - protobuf From 05508cf7cb9da3042fa4b17645102a6406278695 Mon Sep 17 00:00:00 2001 From: Mihailo Milosevic Date: Wed, 13 Nov 2024 20:07:52 +0100 Subject: [PATCH 29/39] [SPARK-42838][SQL] Assign a name to the error class _LEGACY_ERROR_TEMP_2000 ### What changes were proposed in this pull request? Introducing two new error classes instead of _LEGACY_ERROR_TEMP_2000. 
Classes introduced:
- DATETIME_FIELD_OUT_OF_BOUNDS
- INVALID_INTERVAL_WITH_MICROSECONDS_ADDITION

### Why are the changes needed?
We want to assign names to all existing error classes.

### Does this PR introduce _any_ user-facing change?
Yes, the error message changed.

### How was this patch tested?
Existing tests cover error raising.

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes #48332 from mihailom-db/invalid_date_argument_value.

Authored-by: Mihailo Milosevic
Signed-off-by: Max Gekk
---
 .../main/resources/error/error-conditions.json | 17 ++++++++++++-----
 .../expressions/datetimeExpressions.scala | 8 ++++----
 .../spark/sql/catalyst/util/DateTimeUtils.scala | 3 +--
 .../spark/sql/errors/QueryExecutionErrors.scala | 14 ++++++--------
 .../expressions/DateExpressionsSuite.scala | 6 ++----
 .../sql/catalyst/util/DateTimeUtilsSuite.scala | 6 ++----
 .../sql-tests/results/ansi/date.sql.out | 10 ++++++----
 .../sql-tests/results/ansi/timestamp.sql.out | 15 +++++++++------
 .../sql-tests/results/postgreSQL/date.sql.out | 15 +++++++++------
 .../results/timestampNTZ/timestamp-ansi.sql.out | 15 +++++++++------
 10 files changed, 60 insertions(+), 49 deletions(-)

diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json
index 5e1c3f46fd110..eb772f053a889 100644
--- a/common/utils/src/main/resources/error/error-conditions.json
+++ b/common/utils/src/main/resources/error/error-conditions.json
@@ -1101,6 +1101,12 @@
     ],
     "sqlState" : "42K03"
   },
+  "DATETIME_FIELD_OUT_OF_BOUNDS" : {
+    "message" : [
+      "<rangeMessage>. If necessary set <ansiConfig> to \"false\" to bypass this error."
+    ],
+    "sqlState" : "22023"
+  },
   "DATETIME_OVERFLOW" : {
     "message" : [
       "Datetime operation overflow: <operation>."
@@ -2609,6 +2615,12 @@
     },
     "sqlState" : "22006"
   },
+  "INVALID_INTERVAL_WITH_MICROSECONDS_ADDITION" : {
+    "message" : [
+      "Cannot add an interval to a date because its microseconds part is not 0. If necessary set <ansiConfig> to \"false\" to bypass this error."
+    ],
+    "sqlState" : "22006"
+  },
   "INVALID_INVERSE_DISTRIBUTION_FUNCTION" : {
     "message" : [
       "Invalid inverse distribution function <funcName>."
@@ -6905,11 +6917,6 @@
       "Sinks cannot request distribution and ordering in continuous execution mode."
     ]
   },
-  "_LEGACY_ERROR_TEMP_2000" : {
-    "message" : [
-      "<message>. If necessary set <ansiConfig> to false to bypass this error."
-    ]
-  },
   "_LEGACY_ERROR_TEMP_2003" : {
     "message" : [
       "Unsuccessful try to zip maps with unique keys due to exceeding the array size limit <size>."
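A minimal sketch of how the renamed condition surfaces under ANSI mode (illustrative only, not part of the diff; the query is an assumed example, while the error class and message shape follow the ANSI golden files updated below):

```
-- Assumes spark.sql.ansi.enabled is true.
SET spark.sql.ansi.enabled=true;

-- Month 13 is out of range, so this now fails with DATETIME_FIELD_OUT_OF_BOUNDS
-- (previously _LEGACY_ERROR_TEMP_2000), e.g.:
--   Invalid value for MonthOfYear (valid values 1 - 12): 13.
--   If necessary set "spark.sql.ansi.enabled" to "false" to bypass this error.
SELECT make_date(2000, 13, 1);
```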
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index f2ba3ed95b850..fba3927a0bc9c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -2507,14 +2507,14 @@ case class MakeDate( localDateToDays(ld) } catch { case e: java.time.DateTimeException => - if (failOnError) throw QueryExecutionErrors.ansiDateTimeError(e) else null + if (failOnError) throw QueryExecutionErrors.ansiDateTimeArgumentOutOfRange(e) else null } } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") val failOnErrorBranch = if (failOnError) { - "throw QueryExecutionErrors.ansiDateTimeError(e);" + "throw QueryExecutionErrors.ansiDateTimeArgumentOutOfRange(e);" } else { s"${ev.isNull} = true;" } @@ -2839,7 +2839,7 @@ case class MakeTimestamp( } catch { case e: SparkDateTimeException if failOnError => throw e case e: DateTimeException if failOnError => - throw QueryExecutionErrors.ansiDateTimeError(e) + throw QueryExecutionErrors.ansiDateTimeArgumentOutOfRange(e) case _: DateTimeException => null } } @@ -2870,7 +2870,7 @@ case class MakeTimestamp( val zid = ctx.addReferenceObj("zoneId", zoneId, classOf[ZoneId].getName) val d = Decimal.getClass.getName.stripSuffix("$") val failOnErrorBranch = if (failOnError) { - "throw QueryExecutionErrors.ansiDateTimeError(e);" + "throw QueryExecutionErrors.ansiDateTimeArgumentOutOfRange(e);" } else { s"${ev.isNull} = true;" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala index e27ce29fc2318..c9ca3ed864c16 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala @@ -304,8 +304,7 @@ object DateTimeUtils extends SparkDateTimeUtils { start: Int, interval: CalendarInterval): Int = { if (interval.microseconds != 0) { - throw QueryExecutionErrors.ansiIllegalArgumentError( - "Cannot add hours, minutes or seconds, milliseconds, microseconds to a date") + throw QueryExecutionErrors.invalidIntervalWithMicrosecondsAdditionError() } val ld = daysToLocalDate(start).plusMonths(interval.months).plusDays(interval.days) localDateToDays(ld) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala index fb39d3c5d7c6b..ba48000f2aeca 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala @@ -277,22 +277,20 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE summary = "") } - def ansiDateTimeError(e: Exception): SparkDateTimeException = { + def ansiDateTimeArgumentOutOfRange(e: Exception): SparkDateTimeException = { new SparkDateTimeException( - errorClass = "_LEGACY_ERROR_TEMP_2000", + errorClass = "DATETIME_FIELD_OUT_OF_BOUNDS", messageParameters = Map( - "message" -> e.getMessage, + "rangeMessage" -> e.getMessage, "ansiConfig" -> toSQLConf(SQLConf.ANSI_ENABLED.key)), context = 
Array.empty, summary = "") } - def ansiIllegalArgumentError(message: String): SparkIllegalArgumentException = { + def invalidIntervalWithMicrosecondsAdditionError(): SparkIllegalArgumentException = { new SparkIllegalArgumentException( - errorClass = "_LEGACY_ERROR_TEMP_2000", - messageParameters = Map( - "message" -> message, - "ansiConfig" -> toSQLConf(SQLConf.ANSI_ENABLED.key))) + errorClass = "INVALID_INTERVAL_WITH_MICROSECONDS_ADDITION", + messageParameters = Map("ansiConfig" -> toSQLConf(SQLConf.ANSI_ENABLED.key))) } def overflowInSumOfDecimalError( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala index 05d68504a7270..5cd974838fa24 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala @@ -436,10 +436,8 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { withSQLConf((SQLConf.ANSI_ENABLED.key, "true")) { checkErrorInExpression[SparkIllegalArgumentException]( DateAddInterval(Literal(d), Literal(new CalendarInterval(1, 1, 25 * MICROS_PER_HOUR))), - "_LEGACY_ERROR_TEMP_2000", - Map("message" -> - "Cannot add hours, minutes or seconds, milliseconds, microseconds to a date", - "ansiConfig" -> "\"spark.sql.ansi.enabled\"")) + "INVALID_INTERVAL_WITH_MICROSECONDS_ADDITION", + Map("ansiConfig" -> "\"spark.sql.ansi.enabled\"")) } withSQLConf((SQLConf.ANSI_ENABLED.key, "false")) { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala index 96aaf13052b02..790c834d83e97 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala @@ -542,10 +542,8 @@ class DateTimeUtilsSuite extends SparkFunSuite with Matchers with SQLHelper { checkError( exception = intercept[SparkIllegalArgumentException]( dateAddInterval(input, new CalendarInterval(36, 47, 1))), - condition = "_LEGACY_ERROR_TEMP_2000", - parameters = Map( - "message" -> "Cannot add hours, minutes or seconds, milliseconds, microseconds to a date", - "ansiConfig" -> "\"spark.sql.ansi.enabled\"")) + condition = "INVALID_INTERVAL_WITH_MICROSECONDS_ADDITION", + parameters = Map("ansiConfig" -> "\"spark.sql.ansi.enabled\"")) } test("timestamp add interval") { diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/date.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/date.sql.out index 67cd23faf2556..aa283d3249617 100644 --- a/sql/core/src/test/resources/sql-tests/results/ansi/date.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/ansi/date.sql.out @@ -53,10 +53,11 @@ struct<> -- !query output org.apache.spark.SparkDateTimeException { - "errorClass" : "_LEGACY_ERROR_TEMP_2000", + "errorClass" : "DATETIME_FIELD_OUT_OF_BOUNDS", + "sqlState" : "22023", "messageParameters" : { "ansiConfig" : "\"spark.sql.ansi.enabled\"", - "message" : "Invalid value for MonthOfYear (valid values 1 - 12): 13" + "rangeMessage" : "Invalid value for MonthOfYear (valid values 1 - 12): 13" } } @@ -68,10 +69,11 @@ struct<> -- !query output org.apache.spark.SparkDateTimeException { - "errorClass" : "_LEGACY_ERROR_TEMP_2000", + "errorClass" : 
"DATETIME_FIELD_OUT_OF_BOUNDS", + "sqlState" : "22023", "messageParameters" : { "ansiConfig" : "\"spark.sql.ansi.enabled\"", - "message" : "Invalid value for DayOfMonth (valid values 1 - 28/31): 33" + "rangeMessage" : "Invalid value for DayOfMonth (valid values 1 - 28/31): 33" } } diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/timestamp.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/timestamp.sql.out index d75380b16cc83..e3cf1a1549228 100644 --- a/sql/core/src/test/resources/sql-tests/results/ansi/timestamp.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/ansi/timestamp.sql.out @@ -154,10 +154,11 @@ struct<> -- !query output org.apache.spark.SparkDateTimeException { - "errorClass" : "_LEGACY_ERROR_TEMP_2000", + "errorClass" : "DATETIME_FIELD_OUT_OF_BOUNDS", + "sqlState" : "22023", "messageParameters" : { "ansiConfig" : "\"spark.sql.ansi.enabled\"", - "message" : "Invalid value for SecondOfMinute (valid values 0 - 59): 61" + "rangeMessage" : "Invalid value for SecondOfMinute (valid values 0 - 59): 61" } } @@ -185,10 +186,11 @@ struct<> -- !query output org.apache.spark.SparkDateTimeException { - "errorClass" : "_LEGACY_ERROR_TEMP_2000", + "errorClass" : "DATETIME_FIELD_OUT_OF_BOUNDS", + "sqlState" : "22023", "messageParameters" : { "ansiConfig" : "\"spark.sql.ansi.enabled\"", - "message" : "Invalid value for SecondOfMinute (valid values 0 - 59): 99" + "rangeMessage" : "Invalid value for SecondOfMinute (valid values 0 - 59): 99" } } @@ -200,10 +202,11 @@ struct<> -- !query output org.apache.spark.SparkDateTimeException { - "errorClass" : "_LEGACY_ERROR_TEMP_2000", + "errorClass" : "DATETIME_FIELD_OUT_OF_BOUNDS", + "sqlState" : "22023", "messageParameters" : { "ansiConfig" : "\"spark.sql.ansi.enabled\"", - "message" : "Invalid value for SecondOfMinute (valid values 0 - 59): 999" + "rangeMessage" : "Invalid value for SecondOfMinute (valid values 0 - 59): 999" } } diff --git a/sql/core/src/test/resources/sql-tests/results/postgreSQL/date.sql.out b/sql/core/src/test/resources/sql-tests/results/postgreSQL/date.sql.out index 8caf8c54b9f39..d9f4301dd0e8d 100755 --- a/sql/core/src/test/resources/sql-tests/results/postgreSQL/date.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/postgreSQL/date.sql.out @@ -687,10 +687,11 @@ struct<> -- !query output org.apache.spark.SparkDateTimeException { - "errorClass" : "_LEGACY_ERROR_TEMP_2000", + "errorClass" : "DATETIME_FIELD_OUT_OF_BOUNDS", + "sqlState" : "22023", "messageParameters" : { "ansiConfig" : "\"spark.sql.ansi.enabled\"", - "message" : "Invalid date 'FEBRUARY 30'" + "rangeMessage" : "Invalid date 'FEBRUARY 30'" } } @@ -702,10 +703,11 @@ struct<> -- !query output org.apache.spark.SparkDateTimeException { - "errorClass" : "_LEGACY_ERROR_TEMP_2000", + "errorClass" : "DATETIME_FIELD_OUT_OF_BOUNDS", + "sqlState" : "22023", "messageParameters" : { "ansiConfig" : "\"spark.sql.ansi.enabled\"", - "message" : "Invalid value for MonthOfYear (valid values 1 - 12): 13" + "rangeMessage" : "Invalid value for MonthOfYear (valid values 1 - 12): 13" } } @@ -717,10 +719,11 @@ struct<> -- !query output org.apache.spark.SparkDateTimeException { - "errorClass" : "_LEGACY_ERROR_TEMP_2000", + "errorClass" : "DATETIME_FIELD_OUT_OF_BOUNDS", + "sqlState" : "22023", "messageParameters" : { "ansiConfig" : "\"spark.sql.ansi.enabled\"", - "message" : "Invalid value for DayOfMonth (valid values 1 - 28/31): -1" + "rangeMessage" : "Invalid value for DayOfMonth (valid values 1 - 28/31): -1" } } diff --git 
a/sql/core/src/test/resources/sql-tests/results/timestampNTZ/timestamp-ansi.sql.out b/sql/core/src/test/resources/sql-tests/results/timestampNTZ/timestamp-ansi.sql.out index 79996d838c1e5..681306ba9f405 100644 --- a/sql/core/src/test/resources/sql-tests/results/timestampNTZ/timestamp-ansi.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/timestampNTZ/timestamp-ansi.sql.out @@ -154,10 +154,11 @@ struct<> -- !query output org.apache.spark.SparkDateTimeException { - "errorClass" : "_LEGACY_ERROR_TEMP_2000", + "errorClass" : "DATETIME_FIELD_OUT_OF_BOUNDS", + "sqlState" : "22023", "messageParameters" : { "ansiConfig" : "\"spark.sql.ansi.enabled\"", - "message" : "Invalid value for SecondOfMinute (valid values 0 - 59): 61" + "rangeMessage" : "Invalid value for SecondOfMinute (valid values 0 - 59): 61" } } @@ -185,10 +186,11 @@ struct<> -- !query output org.apache.spark.SparkDateTimeException { - "errorClass" : "_LEGACY_ERROR_TEMP_2000", + "errorClass" : "DATETIME_FIELD_OUT_OF_BOUNDS", + "sqlState" : "22023", "messageParameters" : { "ansiConfig" : "\"spark.sql.ansi.enabled\"", - "message" : "Invalid value for SecondOfMinute (valid values 0 - 59): 99" + "rangeMessage" : "Invalid value for SecondOfMinute (valid values 0 - 59): 99" } } @@ -200,10 +202,11 @@ struct<> -- !query output org.apache.spark.SparkDateTimeException { - "errorClass" : "_LEGACY_ERROR_TEMP_2000", + "errorClass" : "DATETIME_FIELD_OUT_OF_BOUNDS", + "sqlState" : "22023", "messageParameters" : { "ansiConfig" : "\"spark.sql.ansi.enabled\"", - "message" : "Invalid value for SecondOfMinute (valid values 0 - 59): 999" + "rangeMessage" : "Invalid value for SecondOfMinute (valid values 0 - 59): 999" } } From 5cc60f46708844c812ee0f21bee4f4b4b70c6d92 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Wed, 13 Nov 2024 14:57:53 -0800 Subject: [PATCH 30/39] [SPARK-50300][BUILD] Use mirror host instead of `archive.apache.org` ### What changes were proposed in this pull request? This PR aims to use `mirror host` instead of `archive.apache.org`. ### Why are the changes needed? Currently, Apache Spark CI is flaky due to the checksum download failure like the following. It took over 9 minutes and failed eventually. - https://github.com/apache/spark/actions/runs/11818847971/job/32927380452 - https://github.com/apache/spark/actions/runs/11818847971/job/32927382179 ``` exec: curl --retry 3 --silent --show-error -L https://www.apache.org/dyn/closer.lua/maven/maven-3/3.9.9/binaries/apache-maven-3.9.9-bin.tar.gz?action=download exec: curl --retry 3 --silent --show-error -L https://archive.apache.org/dist/maven/maven-3/3.9.9/binaries/apache-maven-3.9.9-bin.tar.gz.sha512 curl: (28) Failed to connect to archive.apache.org port 443 after 135199 ms: Connection timed out curl: (28) Failed to connect to archive.apache.org port 443 after 134166 ms: Connection timed out curl: (28) Failed to connect to archive.apache.org port 443 after 135213 ms: Connection timed out curl: (28) Failed to connect to archive.apache.org port 443 after 135260 ms: Connection timed out Verifying checksum from /home/runner/work/spark/spark/build/apache-maven-3.9.9-bin.tar.gz.sha512 shasum: /home/runner/work/spark/spark/build/apache-maven-3.9.9-bin.tar.gz.sha512: no properly formatted SHA checksum lines found Bad checksum from https://archive.apache.org/dist/maven/maven-3/3.9.9/binaries/apache-maven-3.9.9-bin.tar.gz.sha512 Error: Process completed with exit code 2. 
``` **BEFORE** ``` $ build/mvn clean exec: curl --retry 3 --silent --show-error -L https://www.apache.org/dyn/closer.lua/maven/maven-3/3.9.9/binaries/apache-maven-3.9.9-bin.tar.gz?action=download exec: curl --retry 3 --silent --show-error -L https://archive.apache.org/dist/maven/maven-3/3.9.9/binaries/apache-maven-3.9.9-bin.tar.gz.sha512 ``` **AFTER** ``` $ build/mvn clean exec: curl --retry 3 --silent --show-error -L https://www.apache.org/dyn/closer.lua/maven/maven-3/3.9.9/binaries/apache-maven-3.9.9-bin.tar.gz?action=download exec: curl --retry 3 --silent --show-error -L https://www.apache.org/dyn/closer.lua/maven/maven-3/3.9.9/binaries/apache-maven-3.9.9-bin.tar.gz.sha512?action=download ``` ### Does this PR introduce _any_ user-facing change? No, this is a dev-only change. ### How was this patch tested? Pass the CIs. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48836 from dongjoon-hyun/SPARK-50300. Authored-by: Dongjoon Hyun Signed-off-by: Dongjoon Hyun --- build/mvn | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/build/mvn b/build/mvn index 060209ac1ac4d..fef589fc03476 100755 --- a/build/mvn +++ b/build/mvn @@ -56,7 +56,7 @@ install_app() { local binary="${_DIR}/$6" local remote_tarball="${mirror_host}/${url_path}${url_query}" local local_checksum="${local_tarball}.${checksum_suffix}" - local remote_checksum="https://archive.apache.org/dist/${url_path}.${checksum_suffix}" + local remote_checksum="${mirror_host}/${url_path}.${checksum_suffix}${url_query}" local curl_opts="--retry 3 --silent --show-error -L" local wget_opts="--no-verbose" From 33378a6f86e001b20236c7ccd1cebf0acbb54f3e Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Thu, 14 Nov 2024 09:01:19 +0900 Subject: [PATCH 31/39] [SPARK-50304][INFRA] Remove `(any|empty).proto` from RAT exclusion ### What changes were proposed in this pull request? This PR aims to remove `(any|empty).proto` from RAT exclusion. ### Why are the changes needed? `(any|empty).proto` files were never a part of Apache Spark repository. Those files were only used in the initial `Connect` PR and removed before merging. - #37710 - Added: https://github.com/apache/spark/pull/37710/commits/45c7bc55498f38081818424d231ec12576a0dc54 - Excluded from RAT check: https://github.com/apache/spark/pull/37710/commits/cf6b19a991c9bf8c0f208bb2de39dd7121b146a2 - Removed: https://github.com/apache/spark/pull/37710/commits/497198051af069f9afa70c9435dd5d7a099f11f1 ### Does this PR introduce _any_ user-facing change? No. This is a dev-only change. ### How was this patch tested? Pass the CIs or manual check. ``` $ ./dev/check-license Ignored 0 lines in your exclusion files as comments or empty lines. RAT checks passed. ``` ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48837 from dongjoon-hyun/SPARK-50304. 
Authored-by: Dongjoon Hyun Signed-off-by: Hyukjin Kwon --- dev/.rat-excludes | 3 --- 1 file changed, 3 deletions(-) diff --git a/dev/.rat-excludes b/dev/.rat-excludes index 6806c24c7d9fd..d8c9196293950 100644 --- a/dev/.rat-excludes +++ b/dev/.rat-excludes @@ -126,9 +126,6 @@ exported_table/* node_modules spark-events-broken/* SqlBaseLexer.tokens -# Spark Connect related files with custom licence -any.proto -empty.proto .*\.explain .*\.proto.bin LimitedInputStream.java From 891f694207ea83dcfd2ec53e72ca6f0daa093924 Mon Sep 17 00:00:00 2001 From: Hyukjin Kwon Date: Thu, 14 Nov 2024 10:57:04 +0900 Subject: [PATCH 32/39] [SPARK-50306][PYTHON][CONNECT] Support Python 3.13 in Spark Connect ### What changes were proposed in this pull request? This PR proposes to note Python 3.13 in `pyspark-connect` package as its supported version. ### Why are the changes needed? To officially support Python 3.13 ### Does this PR introduce _any_ user-facing change? Yes, in `pyspark-connect` package, Python 3.13 will be explicitly noted as a supported Python version. ### How was this patch tested? CI passed at https://github.com/apache/spark/actions/runs/11824865909 ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48839 from HyukjinKwon/SPARK-50306. Authored-by: Hyukjin Kwon Signed-off-by: Hyukjin Kwon --- python/packaging/connect/setup.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/packaging/connect/setup.py b/python/packaging/connect/setup.py index 6ae16e9a9ad3a..de76d51d0cfdc 100755 --- a/python/packaging/connect/setup.py +++ b/python/packaging/connect/setup.py @@ -212,6 +212,7 @@ "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", "Programming Language :: Python :: Implementation :: CPython", "Programming Language :: Python :: Implementation :: PyPy", "Typing :: Typed", From 2fd47026371488b9409750cba6b697cc61ea7371 Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Thu, 14 Nov 2024 09:59:01 +0800 Subject: [PATCH 33/39] [SPARK-49913][SQL] Add check for unique label names in nested labeled scopes ### What changes were proposed in this pull request? We are introducing checks for unique label names. New rules for label names: - Labels can't have the same name as some of the labels in scope surrounding them - Labels can have the same name as other labels in the same scope **Valid** code: ``` BEGIN lbl: BEGIN SELECT 1; END; lbl: BEGIN SELECT 2; END; BEGIN lbl: WHILE 1=1 DO LEAVE lbl; END WHILE; END; END ``` **Invalid** code: ``` BEGIN lbl: BEGIN lbl: BEGIN SELECT 1; END; END; END ``` #### Design explanation: Even though there are _Listeners_ with `enterRule` and `exitRule` methods to check labels before and remove them from `seenLabels` after visiting node, we favor this approach because minimal changes were needed and code is more compact to avoid dependency issues. Additionally, generating label text would need to be done in 2 places and we wanted to avoid duplicated logic: - `enterRule` - `visitRule` ### Why are the changes needed? It will be needed in future when we release Local Scoped Variables for SQL Scripting so users can target variables from outer scopes if they are shadowed. ### How was this patch tested? New unit tests in 'SqlScriptingParserSuite.scala'. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48795 from miland-db/milan-dankovic_data/unique_labels_scripting. 
Authored-by: Milan Dankovic Signed-off-by: Wenchen Fan --- .../resources/error/error-conditions.json | 6 + .../sql/catalyst/parser/AstBuilder.scala | 151 ++++++----- .../sql/catalyst/parser/ParserUtils.scala | 84 +++++- .../spark/sql/errors/SqlScriptingErrors.scala | 8 + .../parser/SqlScriptingParserSuite.scala | 242 ++++++++++++++++++ 5 files changed, 430 insertions(+), 61 deletions(-) diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json index eb772f053a889..63c54a71b904b 100644 --- a/common/utils/src/main/resources/error/error-conditions.json +++ b/common/utils/src/main/resources/error/error-conditions.json @@ -3411,6 +3411,12 @@ ], "sqlState" : "42K0L" }, + "LABEL_ALREADY_EXISTS" : { + "message" : [ + "The label