[SPARK-45742][CORE][CONNECT][MLLIB][PYTHON] Introduce an implicit function for Scala Array to wrap into `immutable.ArraySeq`

### What changes were proposed in this pull request?
Currently, wrapping an `Array` into an `immutable.ArraySeq` requires calling `immutable.ArraySeq.unsafeWrapArray(array)` explicitly, which makes the code verbose.
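For illustration, the pre-existing pattern looks like this (a minimal sketch with arbitrary values):

```scala
import scala.collection.immutable

val dataArray = Array(1, 2, 3)
// Every call site has to spell out the fully qualified factory method.
val immutableArraySeq = immutable.ArraySeq.unsafeWrapArray(dataArray)
```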

This PR therefore introduces an implicit extension method, `toImmutableArraySeq`, that makes it easy to wrap a Scala `Array` into an `immutable.ArraySeq`.

After this PR, an array can be wrapped as follows:

```scala
import org.apache.spark.util.ArrayImplicits._

val dataArray = Array(1, 2, 3) // any array
val immutableArraySeq = dataArray.toImmutableArraySeq
```
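Because `unsafeWrapArray` is used underneath, the wrap is O(1) and copy-free, and for primitive element types the result is one of the specialized subclasses such as `immutable.ArraySeq.ofInt`; the new `ArrayImplicitsSuite` below verifies this.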

At the same time, this PR replaces the existing uses of `immutable.ArraySeq.unsafeWrapArray(array)` with the new method.

This implicit function will also help the ongoing work on SPARK-45686 and SPARK-45687.

### Why are the changes needed?
It makes the code that wraps a Scala `Array` into an `immutable.ArraySeq` less verbose.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
Pass GitHub Actions

### Was this patch authored or co-authored using generative AI tooling?
No

Closes #43607 from LuciferYang/SPARK-45742.

Authored-by: yangjie01 <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
LuciferYang authored and dongjoon-hyun committed Nov 2, 2023
1 parent a04d4e2 commit 30ec6e3
Showing 8 changed files with 126 additions and 39 deletions.
#### ArrayImplicits.scala (new file)

```scala
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.util

import scala.collection.immutable

/**
 * Implicit methods related to Scala Array.
 */
private[spark] object ArrayImplicits {

  implicit class SparkArrayOps[T](xs: Array[T]) {

    /**
     * Wraps an Array[T] as an immutable.ArraySeq[T] without copying.
     */
    def toImmutableArraySeq: immutable.ArraySeq[T] =
      if (xs eq null) null
      else immutable.ArraySeq.unsafeWrapArray(xs)
  }
}
```
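As a side note, the wrap is "unsafe" in the same sense as `unsafeWrapArray`: no defensive copy is made, so the returned `ArraySeq` shares the caller's array, and a later in-place mutation remains visible through the wrapper. A minimal sketch (arbitrary values):

```scala
import org.apache.spark.util.ArrayImplicits._

val backing = Array(1, 2, 3)
val wrapped = backing.toImmutableArraySeq
backing(0) = 99
// The "immutable" view reflects the mutation, which is why callers
// must not modify the array after wrapping it.
assert(wrapped(0) == 99)
```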
#### SparkSession.scala (Connect Scala client)

```diff
@@ -21,7 +21,6 @@ import java.net.URI
 import java.util.concurrent.TimeUnit._
 import java.util.concurrent.atomic.{AtomicLong, AtomicReference}
 
-import scala.collection.immutable
 import scala.jdk.CollectionConverters._
 import scala.reflect.runtime.universe.TypeTag
 
@@ -45,6 +44,7 @@ import org.apache.spark.sql.internal.{CatalogImpl, SqlApiConf}
 import org.apache.spark.sql.streaming.DataStreamReader
 import org.apache.spark.sql.streaming.StreamingQueryManager
 import org.apache.spark.sql.types.StructType
+import org.apache.spark.util.ArrayImplicits._
 
 /**
  * The entry point to programming Spark with the Dataset and DataFrame API.
@@ -248,7 +248,7 @@ class SparkSession private[sql] (
         proto.SqlCommand
           .newBuilder()
           .setSql(sqlText)
-          .addAllPosArguments(immutable.ArraySeq.unsafeWrapArray(args.map(lit(_).expr)).asJava)))
+          .addAllPosArguments(args.map(lit(_).expr).toImmutableArraySeq.asJava)))
     val plan = proto.Plan.newBuilder().setCommand(cmd)
     // .toBuffer forces that the iterator is consumed and closed
     val responseSeq = client.execute(plan.build()).toBuffer.toSeq
```
#### GrpcExceptionConverter.scala

```diff
@@ -18,7 +18,6 @@ package org.apache.spark.sql.connect.client
 
 import java.time.DateTimeException
 
-import scala.collection.immutable
 import scala.jdk.CollectionConverters._
 import scala.reflect.ClassTag
 
@@ -37,6 +36,7 @@ import org.apache.spark.sql.catalyst.analysis.{NamespaceAlreadyExistsException,
 import org.apache.spark.sql.catalyst.parser.ParseException
 import org.apache.spark.sql.catalyst.trees.Origin
 import org.apache.spark.sql.streaming.StreamingQueryException
+import org.apache.spark.util.ArrayImplicits._
 
 /**
  * GrpcExceptionConverter handles the conversion of StatusRuntimeExceptions into Spark exceptions.
@@ -375,7 +375,7 @@ private[client] object GrpcExceptionConverter {
       FetchErrorDetailsResponse.Error
         .newBuilder()
         .setMessage(message)
-        .addAllErrorTypeHierarchy(immutable.ArraySeq.unsafeWrapArray(classes).asJava)
+        .addAllErrorTypeHierarchy(classes.toImmutableArraySeq.asJava)
         .build()))
   }
 }
```
#### SparkConnectPlanner.scala

```diff
@@ -17,7 +17,6 @@
 
 package org.apache.spark.sql.connect.planner
 
-import scala.collection.immutable
 import scala.collection.mutable
 import scala.jdk.CollectionConverters._
 import scala.util.Try
@@ -80,6 +79,7 @@ import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode, StreamingQ
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
 import org.apache.spark.storage.CacheId
+import org.apache.spark.util.ArrayImplicits._
 import org.apache.spark.util.Utils
 
 final case class InvalidCommandInput(
@@ -3184,9 +3184,9 @@ class SparkConnectPlanner(
       case StreamingQueryManagerCommand.CommandCase.ACTIVE =>
         val active_queries = session.streams.active
         respBuilder.getActiveBuilder.addAllActiveQueries(
-          immutable.ArraySeq
-            .unsafeWrapArray(active_queries
-              .map(query => buildStreamingQueryInstance(query)))
+          active_queries
+            .map(query => buildStreamingQueryInstance(query))
+            .toImmutableArraySeq
             .asJava)
 
       case StreamingQueryManagerCommand.CommandCase.GET_QUERY =>
@@ -3265,15 +3265,16 @@ class SparkConnectPlanner(
         .setGetResourcesCommandResult(
           proto.GetResourcesCommandResult
             .newBuilder()
-            .putAllResources(session.sparkContext.resources.view
-              .mapValues(resource =>
-                proto.ResourceInformation
-                  .newBuilder()
-                  .setName(resource.name)
-                  .addAllAddresses(immutable.ArraySeq.unsafeWrapArray(resource.addresses).asJava)
-                  .build())
-              .toMap
-              .asJava)
+            .putAllResources(
+              session.sparkContext.resources.view
+                .mapValues(resource =>
+                  proto.ResourceInformation
+                    .newBuilder()
+                    .setName(resource.name)
+                    .addAllAddresses(resource.addresses.toImmutableArraySeq.asJava)
+                    .build())
+                .toMap
+                .asJava)
             .build())
         .build())
     }
```
#### ErrorUtils.scala

```diff
@@ -20,7 +20,6 @@ package org.apache.spark.sql.connect.utils
 import java.util.UUID
 
 import scala.annotation.tailrec
-import scala.collection.immutable
 import scala.collection.mutable
 import scala.collection.mutable.ArrayBuffer
 import scala.jdk.CollectionConverters._
@@ -43,6 +42,7 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.sql.connect.config.Connect
 import org.apache.spark.sql.connect.service.{ExecuteEventsManager, SessionHolder, SessionKey, SparkConnectService}
 import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.util.ArrayImplicits._
 
 private[connect] object ErrorUtils extends Logging {
 
@@ -91,21 +91,21 @@ private[connect] object ErrorUtils extends Logging {
 
       if (serverStackTraceEnabled) {
         builder.addAllStackTrace(
-          immutable.ArraySeq
-            .unsafeWrapArray(currentError.getStackTrace
-              .map { stackTraceElement =>
-                val stackTraceBuilder = FetchErrorDetailsResponse.StackTraceElement
-                  .newBuilder()
-                  .setDeclaringClass(stackTraceElement.getClassName)
-                  .setMethodName(stackTraceElement.getMethodName)
-                  .setLineNumber(stackTraceElement.getLineNumber)
-
-                if (stackTraceElement.getFileName != null) {
-                  stackTraceBuilder.setFileName(stackTraceElement.getFileName)
-                }
-
-                stackTraceBuilder.build()
-              })
+          currentError.getStackTrace
+            .map { stackTraceElement =>
+              val stackTraceBuilder = FetchErrorDetailsResponse.StackTraceElement
+                .newBuilder()
+                .setDeclaringClass(stackTraceElement.getClassName)
+                .setMethodName(stackTraceElement.getMethodName)
+                .setLineNumber(stackTraceElement.getLineNumber)
+
+              if (stackTraceElement.getFileName != null) {
+                stackTraceBuilder.setFileName(stackTraceElement.getFileName)
+              }
+
+              stackTraceBuilder.build()
+            }
+            .toImmutableArraySeq
             .asJava)
       }
 
```
#### ArrayImplicitsSuite.scala (new file)

```scala
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.util

import scala.collection.immutable

import org.apache.spark.SparkFunSuite
import org.apache.spark.util.ArrayImplicits._

class ArrayImplicitsSuite extends SparkFunSuite {

  test("Int Array") {
    val data = Array(1, 2, 3)
    val arraySeq = data.toImmutableArraySeq
    assert(arraySeq.getClass === classOf[immutable.ArraySeq.ofInt])
    assert(arraySeq.length === 3)
    assert(arraySeq.unsafeArray.sameElements(data))
  }

  test("TestClass Array") {
    val data = Array(TestClass(1), TestClass(2), TestClass(3))
    val arraySeq = data.toImmutableArraySeq
    assert(arraySeq.getClass === classOf[immutable.ArraySeq.ofRef[TestClass]])
    assert(arraySeq.length === 3)
    assert(arraySeq.unsafeArray.sameElements(data))
  }

  test("Null Array") {
    val data: Array[Int] = null
    val arraySeq = data.toImmutableArraySeq
    assert(arraySeq == null)
  }

  case class TestClass(i: Int)
}
```
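To run just this suite locally, something like `build/sbt "core/testOnly org.apache.spark.util.ArrayImplicitsSuite"` should work, assuming the suite is placed in the `core` module (adjust the module name otherwise).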
#### GaussianMixtureModelWrapper.scala

```diff
@@ -17,12 +17,12 @@
 
 package org.apache.spark.mllib.api.python
 
-import scala.collection.immutable
 import scala.jdk.CollectionConverters._
 
 import org.apache.spark.SparkContext
 import org.apache.spark.mllib.clustering.GaussianMixtureModel
 import org.apache.spark.mllib.linalg.{Vector, Vectors}
+import org.apache.spark.util.ArrayImplicits._
 
 /**
  * Wrapper around GaussianMixtureModel to provide helper methods in Python
@@ -38,7 +38,7 @@ private[python] class GaussianMixtureModelWrapper(model: GaussianMixtureModel) {
     val modelGaussians = model.gaussians.map { gaussian =>
       Array[Any](gaussian.mu, gaussian.sigma)
     }
-    SerDe.dumps(immutable.ArraySeq.unsafeWrapArray(modelGaussians).asJava)
+    SerDe.dumps(modelGaussians.toImmutableArraySeq.asJava)
   }
 
   def predictSoft(point: Vector): Vector = {
```
#### LDAModelWrapper.scala

```diff
@@ -16,12 +16,12 @@
  */
 package org.apache.spark.mllib.api.python
 
-import scala.collection.immutable
 import scala.jdk.CollectionConverters._
 
 import org.apache.spark.SparkContext
 import org.apache.spark.mllib.clustering.LDAModel
 import org.apache.spark.mllib.linalg.Matrix
+import org.apache.spark.util.ArrayImplicits._
 
 /**
  * Wrapper around LDAModel to provide helper methods in Python
@@ -36,11 +36,11 @@ private[python] class LDAModelWrapper(model: LDAModel) {
 
   def describeTopics(maxTermsPerTopic: Int): Array[Byte] = {
     val topics = model.describeTopics(maxTermsPerTopic).map { case (terms, termWeights) =>
-      val jTerms = immutable.ArraySeq.unsafeWrapArray(terms).asJava
-      val jTermWeights = immutable.ArraySeq.unsafeWrapArray(termWeights).asJava
+      val jTerms = terms.toImmutableArraySeq.asJava
+      val jTermWeights = termWeights.toImmutableArraySeq.asJava
       Array[Any](jTerms, jTermWeights)
     }
-    SerDe.dumps(immutable.ArraySeq.unsafeWrapArray(topics).asJava)
+    SerDe.dumps(topics.toImmutableArraySeq.asJava)
   }
 
   def save(sc: SparkContext, path: String): Unit = model.save(sc, path)
```
