diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala index 94bc22ef77d0e..d2724181cd405 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.connect.client.util.{IntegrationTestUtils, RemoteSpa import org.apache.spark.sql.functions.{aggregate, array, broadcast, col, count, lit, rand, sequence, shuffle, struct, transform, udf} import org.apache.spark.sql.types._ -class ClientE2ETestSuite extends RemoteSparkSession { +class ClientE2ETestSuite extends RemoteSparkSession with SQLHelper { // Spark Result test("spark result schema") { @@ -501,16 +501,13 @@ class ClientE2ETestSuite extends RemoteSparkSession { } test("broadcast join") { - spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "-1") - try { + withSQLConf("spark.sql.autoBroadcastJoinThreshold" -> "-1") { val left = spark.range(100).select(col("id"), rand(10).as("a")) val right = spark.range(100).select(col("id"), rand(12).as("a")) val joined = left.join(broadcast(right), left("id") === right("id")).select(left("id"), right("a")) assert(joined.schema.catalogString === "struct") testCapturedStdOut(joined.explain(), "BroadcastHashJoin") - } finally { - spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "10MB") } } diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SQLHelper.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SQLHelper.scala new file mode 100644 index 0000000000000..002785a57c006 --- /dev/null +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SQLHelper.scala @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql + +trait SQLHelper { + + def spark: SparkSession + + /** + * Sets all SQL configurations specified in `pairs`, calls `f`, and then restores all SQL + * configurations. + */ + protected def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = { + val (keys, values) = pairs.unzip + val currentValues = keys.map { key => + if (spark.conf.getOption(key).isDefined) { + Some(spark.conf.get(key)) + } else { + None + } + } + (keys, values).zipped.foreach { (k, v) => + if (spark.conf.isModifiable(k)) { + spark.conf.set(k, v) + } else { + throw new AnalysisException(s"Cannot modify the value of a static config: $k") + } + + } + try f + finally { + keys.zip(currentValues).foreach { + case (key, Some(value)) => spark.conf.set(key, value) + case (key, None) => spark.conf.unset(key) + } + } + } +}