Add new suites
miland-db committed Nov 22, 2024
1 parent db160c2 commit aa2047f
Showing 8 changed files with 1,790 additions and 33 deletions.
5 changes: 5 additions & 0 deletions common/utils/src/main/resources/error/error-conditions.json
@@ -5375,6 +5375,11 @@
"SQL Scripting is under development and not all features are supported. SQL Scripting enables users to write procedural SQL including control flow and error handling. To enable existing features set <sqlScriptingEnabled> to `true`."
]
},
"SQL_SCRIPTING_WITH_POSITIONAL_PARAMETERS" : {
"message" : [
"Positional parameters are not supported with SQL Scripting."
]
},
"STATE_STORE_MULTIPLE_COLUMN_FAMILIES" : {
"message" : [
"Creating multiple column families with <stateStoreProvider> is not supported."
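The new condition above complements the existing UNSUPPORTED_FEATURE.SQL_SCRIPTING gate. As a minimal sketch of how the feature flag is expected to be toggled from user code, assuming SQLConf.SQL_SCRIPTING_ENABLED (referenced later in this commit's tests) is a runtime conf that can be set through spark.conf:

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.internal.SQLConf

val spark = SparkSession.builder().master("local[*]").getOrCreate()

// Without this flag, parsing a BEGIN ... END block raises UNSUPPORTED_FEATURE.SQL_SCRIPTING.
spark.conf.set(SQLConf.SQL_SCRIPTING_ENABLED.key, "true")

// With the flag on, a compound statement parses into a CompoundBody and runs as a script.
spark.sql(
  """BEGIN
    |  SELECT 1;
    |END""".stripMargin).show()
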
@@ -103,6 +103,14 @@ private[sql] object SqlScriptingErrors {
messageParameters = Map("invalidStatement" -> toSQLStmt(stmt)))
}

def positionalParametersAreNotSupportedWithSqlScripting(): Throwable = {
new SqlScriptingException(
origin = null,
errorClass = "UNSUPPORTED_FEATURE.SQL_SCRIPTING_WITH_POSITIONAL_PARAMETERS",
cause = null,
messageParameters = Map.empty)
}

def labelDoesNotExist(
origin: Origin,
labelName: String,
89 changes: 70 additions & 19 deletions sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -22,12 +22,10 @@ import java.nio.file.Paths
import java.util.{ServiceLoader, UUID}
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.atomic.AtomicBoolean

import scala.concurrent.duration.DurationInt
import scala.jdk.CollectionConverters._
import scala.reflect.runtime.universe.TypeTag
import scala.util.control.NonFatal

import org.apache.spark.{SPARK_VERSION, SparkConf, SparkContext, SparkException, TaskContext}
import org.apache.spark.annotation.{DeveloperApi, Experimental, Stable, Unstable}
import org.apache.spark.api.java.JavaRDD
@@ -44,17 +42,19 @@ import org.apache.spark.sql.catalyst.analysis.{NameParameterizedQuery, PosParame
import org.apache.spark.sql.catalyst.encoders._
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, NamedExpression}
import org.apache.spark.sql.catalyst.parser.ParserInterface
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Range}
import org.apache.spark.sql.catalyst.plans.logical.{CompoundBody, LocalRelation, LogicalPlan, Range}
import org.apache.spark.sql.catalyst.types.DataTypeUtils
import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes
import org.apache.spark.sql.catalyst.util.CharVarcharUtils
import org.apache.spark.sql.connector.ExternalCommandRunner
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.errors.{QueryCompilationErrors, SqlScriptingErrors}
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.command.ExternalCommandExecutor
import org.apache.spark.sql.execution.datasources.{DataSource, LogicalRelation}
import org.apache.spark.sql.functions.lit
import org.apache.spark.sql.internal._
import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION
import org.apache.spark.sql.scripting.SqlScriptingExecution
import org.apache.spark.sql.sources.BaseRelation
import org.apache.spark.sql.streaming._
import org.apache.spark.sql.types.{DataType, StructType}
@@ -390,11 +390,6 @@ class SparkSession private(
Dataset.ofRows(self, logicalPlan)
}

private def executeSqlScript(): Unit = {

}


/* ------------------------- *
| Catalog-related methods |
* ------------------------- */
@@ -415,6 +410,32 @@
| Everything else |
* ----------------- */

private def executeSqlScript(script: CompoundBody): DataFrame = {
val sse = new SqlScriptingExecution(script, this)
var df: DataFrame = null
var result: Option[Seq[Row]] = null

while (sse.hasNext) {
sse.withErrorHandling() {
df = sse.next()
if (sse.hasNext) {
df.collect()
} else {
// Collect results from the last DataFrame
result = Some(df.collect().toSeq)
}
}
}

if (result == null) {
emptyDataFrame
} else {
val attributes = DataTypeUtils.toAttributes(result.get.head.schema)
Dataset.ofRows(
self, LocalRelation.fromExternalRows(attributes, result.get))
}
}
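
The helper above runs every statement eagerly and keeps only the rows of the final one, re-materializing them through LocalRelation.fromExternalRows, or falling back to emptyDataFrame when the script produces no result. A hedged sketch of the caller-visible behavior, mirroring the test cases added later in this commit and assuming scripting is enabled:

import org.apache.spark.sql.Row

// Only the last statement's rows back the returned DataFrame; earlier statements
// are executed for their side effects and their results are discarded.
val last = spark.sql(
  """BEGIN
    |  SELECT 1;
    |  SELECT 2;
    |END""".stripMargin)
assert(last.collect().toSeq == Seq(Row(2)))

// A script whose final statement yields no result maps to an empty DataFrame.
val empty = spark.sql(
  """BEGIN
    |  DECLARE x INT;
    |  SET x = 1;
    |  DROP TEMPORARY VARIABLE x;
    |END""".stripMargin)
assert(empty.isEmpty)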

/**
* Executes a SQL query substituting positional parameters by the given arguments,
* returning the result as a `DataFrame`.
@@ -434,13 +455,30 @@
withActive {
val plan = tracker.measurePhase(QueryPlanningTracker.PARSING) {
val parsedPlan = sessionState.sqlParser.parsePlan(sqlText)
if (args.nonEmpty) {
PosParameterizedQuery(parsedPlan, args.map(lit(_).expr).toImmutableArraySeq)
} else {
parsedPlan
parsedPlan match {
case compoundBody: CompoundBody =>
if (args.nonEmpty) {
// Positional parameters are not supported for SQL scripting
throw SqlScriptingErrors.positionalParametersAreNotSupportedWithSqlScripting()
}
compoundBody
case logicalPlan: LogicalPlan =>
if (args.nonEmpty) {
PosParameterizedQuery(logicalPlan, args.map(lit(_).expr).toImmutableArraySeq)
} else {
logicalPlan
}
}
}
Dataset.ofRows(self, plan, tracker)

plan match {
case compoundBody: CompoundBody =>
// execute the SQL script
executeSqlScript(compoundBody)
case logicalPlan: LogicalPlan =>
// execute the standalone SQL statement
Dataset.ofRows(self, plan, tracker)
}
}
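
With the positional-parameter overload above, '?' markers keep working for standalone statements, while combining a compound body with positional arguments now fails fast with the new error condition. A hedged usage sketch:

// Standalone statement: positional markers are bound from the args array, as before.
spark.sql("SELECT ? + ?", Array(1, 2)).show()   // 3

// Compound body plus positional args: rejected up front, before any statement runs,
// with UNSUPPORTED_FEATURE.SQL_SCRIPTING_WITH_POSITIONAL_PARAMETERS.
spark.sql(
  """BEGIN
    |  SELECT ?;
    |END""".stripMargin,
  Array(1))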

/** @inheritdoc */
@@ -472,13 +510,26 @@
withActive {
val plan = tracker.measurePhase(QueryPlanningTracker.PARSING) {
val parsedPlan = sessionState.sqlParser.parsePlan(sqlText)
if (args.nonEmpty) {
NameParameterizedQuery(parsedPlan, args.transform((_, v) => lit(v).expr))
} else {
parsedPlan
parsedPlan match {
case compoundBody: CompoundBody =>
compoundBody
case logicalPlan: LogicalPlan =>
if (args.nonEmpty) {
NameParameterizedQuery(logicalPlan, args.transform((_, v) => lit(v).expr))
} else {
logicalPlan
}
}
}
Dataset.ofRows(self, plan, tracker)

plan match {
case compoundBody: CompoundBody =>
// execute the SQL script
executeSqlScript(compoundBody)
case logicalPlan: LogicalPlan =>
// execute the standalone SQL statement
Dataset.ofRows(self, plan, tracker)
}
}

/** @inheritdoc */
@@ -17,7 +17,7 @@

package org.apache.spark.sql.scripting

import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.catalyst.plans.logical.CompoundBody

/**
@@ -30,7 +30,7 @@ class SqlScriptingExecution(
private val executionPlan: Iterator[CompoundStatementExec] =
SqlScriptingInterpreter().buildExecutionPlan(sqlScript, session)

private var current = if (executionPlan.hasNext) Some(executionPlan.next()) else None
private var current = getNextResult

override def hasNext: Boolean = {
current.isDefined
@@ -40,19 +40,38 @@
if (!hasNext) {
throw new NoSuchElementException("No more statements to execute")
}
moveCurrentToNextResult()
current.get.asInstanceOf[SingleStatementExec].buildDataFrame(session)
val res = current.get.asInstanceOf[SingleStatementExec].buildDataFrame(session)
current = getNextResult
res
}

private def moveCurrentToNextResult(): Unit = {
while (current.isDefined && !current.get.isResult) {
current.get match {
case exec: SingleStatementExec =>
exec.buildDataFrame(session).collect()
case _ => // Do nothing
private def getNextResult: Option[CompoundStatementExec] = {
var currentStatement = if (executionPlan.hasNext) Some(executionPlan.next()) else None
while (currentStatement.isDefined && !currentStatement.get.isResult) {
currentStatement match {
case Some(stmt: SingleStatementExec) if !stmt.isExecuted =>
withErrorHandling() {
stmt.buildDataFrame(session).collect()
}
case _ => // pass
}
current = if (executionPlan.hasNext) Some(executionPlan.next()) else None
currentStatement = if (executionPlan.hasNext) Some(executionPlan.next()) else None
}
currentStatement
}

}
private def handleException(e: Exception): Unit = {
// Rethrow the exception
// TODO: SPARK-48353 Add error handling for SQL scripts
throw e
}

def withErrorHandling()(f: => Unit): Unit = {
try {
f
} catch {
case e: Exception =>
handleException(e)
}
}
}
@@ -82,7 +82,7 @@ trait NonLeafStatementExec extends CompoundStatementExec {

// DataFrame evaluates to True if it is single row, single column
// of boolean type with value True.
val df = Dataset.ofRows(session, statement.parsedPlan)
val df = statement.buildDataFrame(session)
df.schema.fields match {
case Array(field) if field.dataType == BooleanType =>
df.limit(2).collect() match {
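
For context, the match above only accepts a condition DataFrame when it has exactly one column of boolean type and at most one row. A rough, self-contained sketch of that accepted shape follows; the name evaluatesToTrue is hypothetical, and the real helper (cut off in this excerpt) presumably reports invalid shapes through a proper error rather than silently returning false:

import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.types.BooleanType

def evaluatesToTrue(df: DataFrame): Boolean = df.schema.fields match {
  case Array(field) if field.dataType == BooleanType =>
    df.limit(2).collect() match {
      case Array(row) => !row.isNullAt(0) && row.getBoolean(0)  // exactly one row: its value decides
      case _ => false  // zero rows, or more than one row
    }
  case _ => false  // more than one column, or a non-boolean column
}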
@@ -0,0 +1,126 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.scripting

import org.apache.spark.SparkConf
import org.apache.spark.sql.{QueryTest, Row}
import org.apache.spark.sql.catalyst.plans.logical.CompoundBody
import org.apache.spark.sql.catalyst.util.QuotingUtils.toSQLConf
import org.apache.spark.sql.exceptions.SqlScriptingException
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession


/**
* End-to-end tests for SQL Scripting.
* This suite is not intended to heavily test the SQL scripting (parser & interpreter) logic.
* It is rather focused on testing the sql() API: whether it can handle SQL scripts correctly,
* whether results are returned in the expected manner, whether config flags are applied properly, etc.
* For full functionality tests, see SqlScriptingParserSuite and SqlScriptingInterpreterSuite.
*/
class SqlScriptingE2eSuite extends QueryTest with SharedSparkSession {
// Helpers
private def verifySqlScriptResult(sqlText: String, expected: Seq[Row]): Unit = {
val df = spark.sql(sqlText)
checkAnswer(df, expected)
}

// Tests setup
override protected def sparkConf: SparkConf = {
super.sparkConf.set(SQLConf.SQL_SCRIPTING_ENABLED.key, "true")
}

// Tests
test("SQL Scripting not enabled") {
withSQLConf(SQLConf.SQL_SCRIPTING_ENABLED.key -> "false") {
val sqlScriptText =
"""
|BEGIN
| SELECT 1;
|END""".stripMargin
checkError(
exception = intercept[SqlScriptingException] {
spark.sql(sqlScriptText).asInstanceOf[CompoundBody]
},
condition = "UNSUPPORTED_FEATURE.SQL_SCRIPTING",
parameters = Map("sqlScriptingEnabled" -> toSQLConf(SQLConf.SQL_SCRIPTING_ENABLED.key)))
}
}

test("single select") {
val sqlText = "SELECT 1;"
verifySqlScriptResult(sqlText, Seq(Row(1)))
}

test("multiple selects") {
val sqlText =
"""
|BEGIN
| SELECT 1;
| SELECT 2;
|END""".stripMargin
verifySqlScriptResult(sqlText, Seq(Row(2)))
}

test("multi statement - simple") {
withTable("t") {
val sqlScript =
"""
|BEGIN
| CREATE TABLE t (a INT, b STRING, c DOUBLE) USING parquet;
| INSERT INTO t VALUES (1, 'a', 1.0);
| SELECT a FROM t;
|END
|""".stripMargin
verifySqlScriptResult(sqlScript, Seq(Row(1)))
}
}

test("last statement without result") {
val sqlScript =
"""
|BEGIN
| DECLARE x INT;
| SET x = 1;
| DROP TEMPORARY VARIABLE x;
|END
|""".stripMargin
verifySqlScriptResult(sqlScript, Seq.empty)
}

test("positional params") {
val sqlScriptText =
"""
|BEGIN
| SELECT 1;
| IF ? > 10 THEN
| SELECT ?;
| ELSE
| SELECT ?;
| END IF;
|END""".stripMargin
// Define an array with SQL parameters in the correct order
val args: Array[Any] = Array(5, "greater", "smaller")
checkError(
exception = intercept[SqlScriptingException] {
spark.sql(sqlScriptText, args).asInstanceOf[CompoundBody]
},
condition = "UNSUPPORTED_FEATURE.SQL_SCRIPTING_WITH_POSITIONAL_PARAMETERS",
parameters = Map.empty)
}
}