[SPARK-42680][CONNECT][TESTS] Create the helper function withSQLConf for connect test framework

### What changes were proposed in this pull request?
Spark SQL has a helper function, `withSQLConf`, that makes it easy to change SQL configs within the scope of a test. This PR adds an equivalent helper to the Spark Connect test framework.
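For illustration, usage looks like the following sketch (the test name and config key here are hypothetical examples, not part of this change):

```scala
// Hypothetical test in a suite that mixes in SQLHelper.
test("ANSI mode is scoped to the block") {
  withSQLConf("spark.sql.ansi.enabled" -> "true") {
    // The config is set only for the duration of this block and
    // restored (or unset) afterwards, even if the body throws.
    assert(spark.conf.get("spark.sql.ansi.enabled") == "true")
  }
}
```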

### Why are the changes needed?
This makes Spark Connect test cases easier to implement.

### Does this PR introduce _any_ user-facing change?
No, this is a test-only change.

### How was this patch tested?
An existing test case was updated to use the new helper.

Closes apache#40296 from beliefer/SPARK-42680.

Authored-by: Jiaan Geng <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
beliefer authored and HyukjinKwon committed Mar 7, 2023
1 parent dfdc4a1 commit 201e08c
Showing 2 changed files with 54 additions and 5 deletions.
ClientE2ETestSuite.scala
@@ -31,7 +31,7 @@
import org.apache.spark.sql.connect.client.util.{IntegrationTestUtils, RemoteSparkSession}
import org.apache.spark.sql.functions.{aggregate, array, broadcast, col, count, lit, rand, sequence, shuffle, struct, transform, udf}
import org.apache.spark.sql.types._

-class ClientE2ETestSuite extends RemoteSparkSession {
+class ClientE2ETestSuite extends RemoteSparkSession with SQLHelper {

// Spark Result
test("spark result schema") {
@@ -501,16 +501,13 @@ class ClientE2ETestSuite extends RemoteSparkSession {
}

test("broadcast join") {
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "-1")
try {
withSQLConf("spark.sql.autoBroadcastJoinThreshold" -> "-1") {
val left = spark.range(100).select(col("id"), rand(10).as("a"))
val right = spark.range(100).select(col("id"), rand(12).as("a"))
val joined =
left.join(broadcast(right), left("id") === right("id")).select(left("id"), right("a"))
assert(joined.schema.catalogString === "struct<id:bigint,a:double>")
testCapturedStdOut(joined.explain(), "BroadcastHashJoin")
} finally {
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "10MB")
}
}
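Note the improvement here: the old code restored the threshold by hard-coding the default `10MB`, whereas `withSQLConf` restores whatever value was previously set, or unsets the key if it had no explicit value.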

SQLHelper.scala (new file)
@@ -0,0 +1,52 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql

trait SQLHelper {

  def spark: SparkSession

  /**
   * Sets all SQL configurations specified in `pairs`, calls `f`, and then restores all SQL
   * configurations.
   */
  protected def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = {
    val (keys, values) = pairs.unzip
    val currentValues = keys.map { key =>
      if (spark.conf.getOption(key).isDefined) {
        Some(spark.conf.get(key))
      } else {
        None
      }
    }
    (keys, values).zipped.foreach { (k, v) =>
      if (spark.conf.isModifiable(k)) {
        spark.conf.set(k, v)
      } else {
        throw new AnalysisException(s"Cannot modify the value of a static config: $k")
      }
    }
    try f
    finally {
      keys.zip(currentValues).foreach {
        case (key, Some(value)) => spark.conf.set(key, value)
        case (key, None) => spark.conf.unset(key)
      }
    }
  }
}
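As a sketch of how a suite picks this up, the trait is mixed in alongside `RemoteSparkSession`, which supplies the `spark` member that satisfies the abstract `def spark` (the suite and test names below are hypothetical):

```scala
// Hypothetical Spark Connect test suite using the new helper.
class MyConnectSuite extends RemoteSparkSession with SQLHelper {
  test("shuffle partitions config is scoped") {
    withSQLConf("spark.sql.shuffle.partitions" -> "5") {
      assert(spark.conf.get("spark.sql.shuffle.partitions") == "5")
    }
    // Outside the block, the previous value (or the default) applies again.
  }
}
```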
