[SPARK-49249][SPARK-49122] Artifact isolation in Spark Classic
### What changes were proposed in this pull request?

This PR makes the isolation feature introduced by the `SparkSession.addArtifact` API (added in #47631) work with Spark SQL.

Note that this PR does not enable isolation for the following two use cases:

- PySpark
	- Future work is needed to add an API that supports adding isolated Python UDTFs.
- When Hive is used as the metastore
	- Hive UDFs are a major blocker because artifacts can be used outside a `SparkSession`, letting resources escape our session scope.

### Why are the changes needed?

Because artifact isolation did not work in Spark Classic before :)

### Does this PR introduce _any_ user-facing change?

Yes, the user can add a new artifact in the REPL and use it in the current REPL session.
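
As a rough sketch of the new workflow (the path is hypothetical; `IntSumUdf` mirrors the `UDF2[Long, Long, Long]` test resource added by this PR):

```scala
import org.apache.spark.sql.types.DataTypes

// Add a compiled UDF class as a session artifact (hypothetical path), register
// it by class name, and use it within the same session.
spark.addArtifact("/path/to/IntSumUdf.class")
spark.udf.registerJava("intSum", "IntSumUdf", DataTypes.LongType)
spark.range(5).selectExpr("intSum(id, id + 1)").show()
```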

### How was this patch tested?

Added a new test.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #48120 from xupefei/session-artifact-apply.

Authored-by: Paddy Xu <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
xupefei authored and HyukjinKwon committed Nov 13, 2024
1 parent 1198117 commit 2633035
Showing 23 changed files with 344 additions and 146 deletions.
@@ -19,6 +19,7 @@ package org.apache.spark.sql

import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.internal.UdfToProtoUtils
import org.apache.spark.sql.types.DataType

/**
* Functions for registering user-defined functions. Use `SparkSession.udf` to access this:
@@ -30,6 +31,11 @@ import org.apache.spark.sql.internal.UdfToProtoUtils
* @since 3.5.0
*/
class UDFRegistration(session: SparkSession) extends api.UDFRegistration {
override def registerJava(name: String, className: String, returnDataType: DataType): Unit = {
throw new UnsupportedOperationException(
"registerJava is currently not supported in Spark Connect.")
}

override protected def register(
name: String,
udf: UserDefinedFunction,
8 changes: 6 additions & 2 deletions core/src/main/scala/org/apache/spark/SparkFiles.scala
@@ -27,8 +27,12 @@ object SparkFiles {
/**
* Get the absolute path of a file added through `SparkContext.addFile()`.
*/
def get(filename: String): String =
new File(getRootDirectory(), filename).getAbsolutePath()
def get(filename: String): String = {
val jobArtifactUUID = JobArtifactSet
.getCurrentJobArtifactState.map(_.uuid).getOrElse("default")
val withUuid = if (jobArtifactUUID == "default") filename else s"$jobArtifactUUID/$filename"
new File(getRootDirectory(), withUuid).getAbsolutePath
}

/**
* Get the root directory that contains files added through `SparkContext.addFile()`.
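
For illustration, a self-contained sketch of the path-resolution rule added above (the helper name is hypothetical; the real implementation reads the UUID from `JobArtifactSet`):

```scala
import java.io.File

// Files added under an isolated session resolve to <root>/<uuid>/<filename>;
// the non-isolated ("default") state keeps the original flat layout.
def resolveAddedFile(rootDir: String, filename: String, jobArtifactUUID: String): String = {
  val withUuid = if (jobArtifactUUID == "default") filename else s"$jobArtifactUUID/$filename"
  new File(rootDir, withUuid).getAbsolutePath
}

// resolveAddedFile("/tmp/spark-files", "data.csv", "1234-abcd")
//   => "/tmp/spark-files/1234-abcd/data.csv"
```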
2 changes: 2 additions & 0 deletions python/pyspark/core/context.py
@@ -84,6 +84,8 @@
DEFAULT_CONFIGS: Dict[str, Any] = {
"spark.serializer.objectStreamReset": 100,
"spark.rdd.compress": True,
# Disable artifact isolation in PySpark; otherwise user-added .py files won't work
"spark.sql.artifact.isolation.enabled": "false",
}

T = TypeVar("T")
6 changes: 5 additions & 1 deletion python/pyspark/sql/connect/session.py
@@ -1037,7 +1037,11 @@ def _start_connect_server(master: str, opts: Dict[str, Any]) -> None:
os.environ["SPARK_LOCAL_CONNECT"] = "1"

# Configurations to be set if unset.
default_conf = {"spark.plugins": "org.apache.spark.sql.connect.SparkConnectPlugin"}
default_conf = {
"spark.plugins": "org.apache.spark.sql.connect.SparkConnectPlugin",
"spark.sql.artifact.isolation.enabled": "true",
"spark.sql.artifact.isolation.always.apply.classloader": "true",
}

if "SPARK_TESTING" in os.environ:
# For testing, we use 0 to use an ephemeral port to allow parallel testing.
4 changes: 4 additions & 0 deletions repl/src/main/scala/org/apache/spark/repl/Main.scala
@@ -25,6 +25,7 @@ import scala.tools.nsc.GenericRunnerSettings
import org.apache.spark._
import org.apache.spark.internal.Logging
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION
import org.apache.spark.util.Utils

@@ -95,6 +96,9 @@ object Main extends Logging {
// initialization in certain cases, there's an initialization order issue that prevents
// this from being set after SparkContext is instantiated.
conf.set("spark.repl.class.outputDir", outputDir.getAbsolutePath())
// Disable isolation for the REPL to avoid storing in-line classes in an isolated directory,
// which would prevent the REPL classloader from finding them.
conf.set(SQLConf.ARTIFACTS_SESSION_ISOLATION_ENABLED, false)
if (execUri != null) {
conf.set("spark.executor.uri", execUri)
}
Binary file added repl/src/test/resources/IntSumUdf.class
22 changes: 22 additions & 0 deletions repl/src/test/resources/IntSumUdf.scala
@@ -0,0 +1,22 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

import org.apache.spark.sql.api.java.UDF2

class IntSumUdf extends UDF2[Long, Long, Long] {
override def call(t1: Long, t2: Long): Long = t1 + t2
}
63 changes: 63 additions & 0 deletions repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala
@@ -396,4 +396,67 @@ class ReplSuite extends SparkFunSuite {
Main.sparkContext.stop()
System.clearProperty("spark.driver.port")
}

test("register UDF via SparkSession.addArtifact") {
val artifactPath = new File("src/test/resources").toPath
val intSumUdfPath = artifactPath.resolve("IntSumUdf.class")
val output = runInterpreterInPasteMode("local",
s"""
|import org.apache.spark.sql.api.java.UDF2
|import org.apache.spark.sql.types.DataTypes
|
|spark.addArtifact("${intSumUdfPath.toString}")
|
|spark.udf.registerJava("intSum", "IntSumUdf", DataTypes.LongType)
|
|val r = spark.range(5)
| .withColumn("id2", col("id") + 1)
| .selectExpr("intSum(id, id2)")
| .collect()
|assert(r.map(_.getLong(0)).toSeq == Seq(1, 3, 5, 7, 9))
|
""".stripMargin)
assertContains("Array([1], [3], [5], [7], [9])", output)
assertDoesNotContain("error:", output)
assertDoesNotContain("Exception", output)
assertDoesNotContain("assertion failed", output)

// The UDF should not work in a new REPL session.
val anotherOutput = runInterpreterInPasteMode("local",
s"""
|val r = spark.range(5)
| .withColumn("id2", col("id") + 1)
| .selectExpr("intSum(id, id2)")
| .collect()
|
""".stripMargin)
assertContains(
"[UNRESOLVED_ROUTINE] Cannot resolve routine `intSum` on search path",
anotherOutput)
}

test("register a class via SparkSession.addArtifact") {
val artifactPath = new File("src/test/resources").toPath
val intSumUdfPath = artifactPath.resolve("IntSumUdf.class")
val output = runInterpreterInPasteMode("local",
s"""
|import org.apache.spark.sql.functions.udf
|
|spark.addArtifact("${intSumUdfPath.toString}")
|
|val intSumUdf = udf((x: Long, y: Long) => new IntSumUdf().call(x, y))
|spark.udf.register("intSum", intSumUdf)
|
|val r = spark.range(5)
| .withColumn("id2", col("id") + 1)
| .selectExpr("intSum(id, id2)")
| .collect()
|assert(r.map(_.getLong(0)).toSeq == Seq(1, 3, 5, 7, 9))
|
""".stripMargin)
assertContains("Array([1], [3], [5], [7], [9])", output)
assertDoesNotContain("error:", output)
assertDoesNotContain("Exception", output)
assertDoesNotContain("assertion failed", output)
}
}
@@ -35,6 +35,23 @@ import org.apache.spark.sql.types.DataType
*/
abstract class UDFRegistration {

/**
* Register a Java UDF class using its class name. The class must implement one of the UDF
* interfaces in the [[org.apache.spark.sql.api.java]] package and must be discoverable by the
* current session's class loader.
*
* @param name
* Name of the UDF.
* @param className
* Fully qualified class name of the UDF.
* @param returnDataType
* Return type of the UDF. If it is `null`, Spark will try to infer it via reflection.
* @note
* this method is currently not supported in Spark Connect.
* @since 4.0.0
*/
def registerJava(name: String, className: String, returnDataType: DataType): Unit

/**
* Registers a user-defined function (UDF), for a UDF that's already defined using the Dataset
* API (i.e. of type UserDefinedFunction). To change a UDF to nondeterministic, call the API
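
A brief usage sketch of this newly public API (assuming the class file has already been made visible to the session, e.g. via `spark.addArtifact`):

```scala
import org.apache.spark.sql.types.DataTypes

// Register the class by its fully qualified name; it must be discoverable by
// the session's classloader.
spark.udf.registerJava("intSum", "IntSumUdf", DataTypes.LongType)
spark.sql("SELECT intSum(id, id + 1) FROM range(5)").show()

// Per the doc above, passing null for returnDataType asks Spark to infer the
// return type via reflection.
```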
@@ -3957,6 +3957,28 @@ object SQLConf {
.intConf
.createWithDefault(20)

val ARTIFACTS_SESSION_ISOLATION_ENABLED =
buildConf("spark.sql.artifact.isolation.enabled")
.internal()
.doc("When enabled for a Spark Session, artifacts (such as JARs, files, archives) added to " +
"this session are isolated from other sessions within the same Spark instance. When " +
"disabled for a session, artifacts added to this session are visible to other sessions " +
"that have this config disabled. This config can only be set during the creation of a " +
"Spark Session and will have no effect when changed in the middle of session usage.")
.version("4.0.0")
.booleanConf
.createWithDefault(true)

val ARTIFACTS_SESSION_ISOLATION_ALWAYS_APPLY_CLASSLOADER =
buildConf("spark.sql.artifact.isolation.always.apply.classloader")
.internal()
.doc("When enabled, the classloader holding per-session artifacts will always be applied " +
"during SQL executions (useful for Spark Connect). When disabled, the classloader will " +
"be applied only when any artifact is added to the session.")
.version("4.0.0")
.booleanConf
.createWithDefault(false)

val FAST_HASH_AGGREGATE_MAX_ROWS_CAPACITY_BIT =
buildConf("spark.sql.codegen.aggregate.fastHashMap.capacityBit")
.internal()
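
As a configuration sketch (local mode; values shown for illustration), the two new flags are read when the session is created, so they belong on the builder rather than being changed mid-session:

```scala
import org.apache.spark.sql.SparkSession

// Per the doc strings above, both flags only take effect at session creation time.
val spark = SparkSession.builder()
  .master("local[*]")
  .config("spark.sql.artifact.isolation.enabled", "true")
  .config("spark.sql.artifact.isolation.always.apply.classloader", "true")
  .getOrCreate()
```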
@@ -25,6 +25,7 @@ import scala.sys.exit
import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.connect.service.SparkConnectService
import org.apache.spark.sql.internal.SQLConf

/**
* A simple main class method to start the spark connect server as a service for client tests
@@ -40,6 +41,8 @@ private[sql] object SimpleSparkConnectService {
def main(args: Array[String]): Unit = {
val conf = new SparkConf()
.set("spark.plugins", "org.apache.spark.sql.connect.SparkConnectPlugin")
.set(SQLConf.ARTIFACTS_SESSION_ISOLATION_ENABLED, true)
.set(SQLConf.ARTIFACTS_SESSION_ISOLATION_ALWAYS_APPLY_CLASSLOADER, true)
val sparkSession = SparkSession.builder().config(conf).getOrCreate()
val sparkContext = sparkSession.sparkContext // init spark context
// scalastyle:off println
@@ -20,6 +20,7 @@ package org.apache.spark.sql.connect.service
import org.apache.spark.internal.{Logging, MDC}
import org.apache.spark.internal.LogKeys.{HOST, PORT}
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.internal.SQLConf

/**
* The Spark Connect server
@@ -28,7 +29,11 @@ object SparkConnectServer extends Logging {
def main(args: Array[String]): Unit = {
// Set the active Spark Session, and starts SparkEnv instance (via Spark Context)
logInfo("Starting Spark session.")
val session = SparkSession.builder().getOrCreate()
val session = SparkSession
.builder()
.config(SQLConf.ARTIFACTS_SESSION_ISOLATION_ENABLED.key, true)
.config(SQLConf.ARTIFACTS_SESSION_ISOLATION_ALWAYS_APPLY_CLASSLOADER.key, true)
.getOrCreate()
try {
try {
SparkConnectService.start(session.sparkContext)
@@ -789,7 +789,11 @@ object SparkSession extends api.BaseSparkSessionCompanion with Logging {
/** @inheritdoc */
override def enableHiveSupport(): this.type = synchronized {
if (hiveClassesArePresent) {
// TODO(SPARK-50244): We now isolate artifacts added by the `ADD JAR` command. This will
// break an existing Hive use case (one session adds JARs and another session uses them).
// We need to decide whether/how to enable isolation for Hive.
super.enableHiveSupport()
.config(SQLConf.ARTIFACTS_SESSION_ISOLATION_ENABLED.key, false)
} else {
throw new IllegalArgumentException(
"Unable to instantiate SparkSession with Hive support because " +
18 changes: 5 additions & 13 deletions sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
@@ -32,7 +32,6 @@ import org.apache.spark.sql.execution.python.UserDefinedPythonFunction
import org.apache.spark.sql.expressions.{SparkUserDefinedFunction, UserDefinedAggregateFunction, UserDefinedAggregator, UserDefinedFunction}
import org.apache.spark.sql.internal.UserDefinedFunctionUtils.toScalaUDF
import org.apache.spark.sql.types.DataType
import org.apache.spark.util.Utils

/**
* Functions for registering user-defined functions. Use `SparkSession.udf` to access this:
@@ -44,7 +43,7 @@ import org.apache.spark.util.Utils
* @since 1.3.0
*/
@Stable
class UDFRegistration private[sql] (functionRegistry: FunctionRegistry)
class UDFRegistration private[sql] (session: SparkSession, functionRegistry: FunctionRegistry)
extends api.UDFRegistration
with Logging {
protected[sql] def registerPython(name: String, udf: UserDefinedPythonFunction): Unit = {
@@ -121,7 +120,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry)
*/
private[sql] def registerJavaUDAF(name: String, className: String): Unit = {
try {
val clazz = Utils.classForName[AnyRef](className)
val clazz = session.artifactManager.classloader.loadClass(className)
if (!classOf[UserDefinedAggregateFunction].isAssignableFrom(clazz)) {
throw QueryCompilationErrors
.classDoesNotImplementUserDefinedAggregateFunctionError(className)
@@ -137,17 +136,10 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry)
}

// scalastyle:off line.size.limit
/**
* Register a Java UDF class using reflection, for use from pyspark
*
* @param name udf name
* @param className fully qualified class name of udf
* @param returnDataType return type of udf. If it is null, spark would try to infer
* via reflection.
*/
private[sql] def registerJava(name: String, className: String, returnDataType: DataType): Unit = {

override def registerJava(name: String, className: String, returnDataType: DataType): Unit = {
try {
val clazz = Utils.classForName[AnyRef](className)
val clazz = session.artifactManager.classloader.loadClass(className)
val udfInterfaces = clazz.getGenericInterfaces
.filter(_.isInstanceOf[ParameterizedType])
.map(_.asInstanceOf[ParameterizedType])