[SPARK-45380][CORE][SQL][ML][R][CONNECT] Replace `mutable.WrappedArray` with `mutable.ArraySeq`

### What changes were proposed in this pull request?
This PR replaces all `mutable.WrappedArray` with `mutable.ArraySeq`. It also unifies all uses of `scala.collection.mutable.ArraySeq`: call sites first `import scala.collection.mutable` and then refer to `mutable.ArraySeq`, as sketched below.
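
A minimal sketch of the unified usage pattern (illustrative only, not a specific line from this diff):

```scala
import scala.collection.mutable

// Before this PR (deprecated): scala.collection.mutable.WrappedArray.make(...)
// After this PR:
val seq: mutable.ArraySeq[Int] = mutable.ArraySeq.make(Array(1, 2, 3))
```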

### Why are the changes needed?
In Scala 2.13.0 and later, `WrappedArray` is deprecated; its designated replacement is `mutable.ArraySeq`:

```scala
package object mutable {
  deprecated("Use ArraySeq instead of WrappedArray; it can represent both, boxed and unboxed arrays", "2.13.0")
  type WrappedArray[X] = ArraySeq[X]
  deprecated("Use ArraySeq instead of WrappedArray; it can represent both, boxed and unboxed arrays", "2.13.0")
  val WrappedArray = ArraySeq
...
}
```
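
Because `WrappedArray` is now only a deprecated alias for `ArraySeq`, the migration is a mechanical rename. A minimal sketch of the behavior the renamed call sites rely on, assuming plain Scala 2.13 (not a line from this diff): `mutable.ArraySeq.make` wraps an existing array without copying it.

```scala
import scala.collection.mutable

val backing = Array[Byte](8, 6)

// make wraps the array in place; no elements are copied.
val seq: mutable.ArraySeq[Byte] = mutable.ArraySeq.make(backing)
assert(seq == Seq[Byte](8, 6))

// The wrapper shares storage with the original array, so writes to
// the backing array are visible through the Seq view.
backing(0) = 42
assert(seq(0) == 42)
```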

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
Passes GitHub Actions.

### Was this patch authored or co-authored using generative AI tooling?
No

Closes #43178 from LuciferYang/SPARK-45380.

Authored-by: yangjie01 <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
LuciferYang authored and dongjoon-hyun committed Sep 29, 2023
1 parent 0b68e41 commit 4863dec
Showing 19 changed files with 53 additions and 49 deletions.
@@ -3041,7 +3041,7 @@ class PlanGenerationTestSuite
fn.lit('T'),
fn.lit(Array.tabulate(10)(i => ('A' + i).toChar)),
fn.lit(Array.tabulate(23)(i => (i + 120).toByte)),
-fn.lit(mutable.WrappedArray.make(Array[Byte](8.toByte, 6.toByte))),
+fn.lit(mutable.ArraySeq.make(Array[Byte](8.toByte, 6.toByte))),
fn.lit(null),
fn.lit(java.time.LocalDate.of(2020, 10, 10)),
fn.lit(Decimal.apply(BigDecimal(8997620, 6))),
@@ -3110,7 +3110,7 @@ class PlanGenerationTestSuite
fn.typedLit('T'),
fn.typedLit(Array.tabulate(10)(i => ('A' + i).toChar)),
fn.typedLit(Array.tabulate(23)(i => (i + 120).toByte)),
-fn.typedLit(mutable.WrappedArray.make(Array[Byte](8.toByte, 6.toByte))),
+fn.typedLit(mutable.ArraySeq.make(Array[Byte](8.toByte, 6.toByte))),
fn.typedLit(null),
fn.typedLit(java.time.LocalDate.of(2020, 10, 10)),
fn.typedLit(Decimal.apply(BigDecimal(8997620, 6))),
@@ -497,8 +497,8 @@ class ArrowEncoderSuite extends ConnectFunSuite with BeforeAndAfterAll {
}

test("wrapped array") {
-val encoder = ScalaReflection.encoderFor[mutable.WrappedArray[Int]]
-val input = mutable.WrappedArray.make[Int](Array(1, 98, 7, 6))
+val encoder = ScalaReflection.encoderFor[mutable.ArraySeq[Int]]
+val input = mutable.ArraySeq.make[Int](Array(1, 98, 7, 6))
val iterator = roundTrip(encoder, Iterator.single(input))
val Seq(result) = iterator.toSeq
assert(result == input)
@@ -511,7 +511,7 @@ class ArrowEncoderSuite extends ConnectFunSuite with BeforeAndAfterAll {
val encoder = toRowEncoder(schema)
val iterator = roundTrip(encoder, Iterator.single(Row(Seq())))
val Seq(Row(raw)) = iterator.toSeq
-val seq = raw.asInstanceOf[mutable.WrappedArray[String]]
+val seq = raw.asInstanceOf[mutable.ArraySeq[String]]
assert(seq.isEmpty)
assert(seq.array.getClass == classOf[Array[String]])
iterator.close()
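
For context on the `seq.array.getClass` assertion above: `ArraySeq.make` dispatches on the runtime type of the array, so the concrete array class is preserved behind the `Seq` interface. A standalone sketch, assuming plain Scala 2.13 outside Spark:

```scala
import scala.collection.mutable

// An Array[String] stays an Array[String] behind the Seq interface
// (ArraySeq.ofRef), which the test above depends on.
val strings = mutable.ArraySeq.make(Array("a", "b"))
assert(strings.array.getClass == classOf[Array[String]])

// A primitive Array[Int] stays unboxed (ArraySeq.ofInt).
val ints = mutable.ArraySeq.make(Array(1, 2, 3))
assert(ints.array.getClass == classOf[Array[Int]])
```
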
@@ -213,11 +213,11 @@ object ArrowDeserializers

case (IterableEncoder(tag, element, _, _), v: ListVector) =>
val deserializer = deserializerFor(element, v.getDataVector, timeZoneId)
-if (isSubClass(Classes.WRAPPED_ARRAY, tag)) {
-// Wrapped array is a bit special because we need to use an array of the element type.
+if (isSubClass(Classes.MUTABLE_ARRAY_SEQ, tag)) {
+// mutable ArraySeq is a bit special because we need to use an array of the element type.
// Some parts of our codebase (unfortunately) rely on this for type inference on results.
-new VectorFieldDeserializer[mutable.WrappedArray[Any], ListVector](v) {
-def value(i: Int): mutable.WrappedArray[Any] = {
+new VectorFieldDeserializer[mutable.ArraySeq[Any], ListVector](v) {
+def value(i: Int): mutable.ArraySeq[Any] = {
val array = getArray(vector, i, deserializer)(element.clsTag)
ScalaCollectionUtils.wrap(array)
}
@@ -16,6 +16,7 @@
*/
package org.apache.spark.sql.connect.client.arrow

+import scala.collection.mutable
import scala.jdk.CollectionConverters._
import scala.reflect.ClassTag

@@ -24,7 +25,7 @@ import org.apache.arrow.vector.complex.StructVector

private[arrow] object ArrowEncoderUtils {
object Classes {
-val WRAPPED_ARRAY: Class[_] = classOf[scala.collection.mutable.WrappedArray[_]]
+val MUTABLE_ARRAY_SEQ: Class[_] = classOf[mutable.ArraySeq[_]]
val ITERABLE: Class[_] = classOf[scala.collection.Iterable[_]]
val MAP: Class[_] = classOf[scala.collection.Map[_, _]]
val JLIST: Class[_] = classOf[java.util.List[_]]
@@ -31,7 +31,7 @@ private[arrow] object ScalaCollectionUtils {
def getMapCompanion(tag: ClassTag[_]): MapFactory[Map] = {
resolveCompanion[MapFactory[Map]](tag)
}
-def wrap[T](array: AnyRef): mutable.WrappedArray[T] = {
-mutable.WrappedArray.make(array.asInstanceOf[Array[T]])
+def wrap[T](array: AnyRef): mutable.ArraySeq[T] = {
+mutable.ArraySeq.make(array.asInstanceOf[Array[T]])
}
}
@@ -84,7 +84,7 @@ object LiteralValueProtoConverter {
case v: Char => builder.setString(v.toString)
case v: Array[Char] => builder.setString(String.valueOf(v))
case v: Array[Byte] => builder.setBinary(ByteString.copyFrom(v))
-case v: collection.mutable.WrappedArray[_] => toLiteralProtoBuilder(v.array)
+case v: mutable.ArraySeq[_] => toLiteralProtoBuilder(v.array)
case v: LocalDate => builder.setDate(v.toEpochDay.toInt)
case v: Decimal =>
builder.setDecimal(decimalBuilder(Math.max(v.precision, v.scale), v.scale, v.toString))
@@ -162,7 +162,7 @@
}

(literal, dataType) match {
-case (v: collection.mutable.WrappedArray[_], ArrayType(_, _)) =>
+case (v: mutable.ArraySeq[_], ArrayType(_, _)) =>
toLiteralProtoBuilder(v.array, dataType)
case (v: Array[Byte], ArrayType(_, _)) =>
toLiteralProtoBuilder(v)
4 changes: 2 additions & 2 deletions core/src/main/scala/org/apache/spark/api/r/SerDe.scala
@@ -293,9 +293,9 @@ private[spark] object SerDe {
} else {
// Convert ArrayType collected from DataFrame to Java array
// Collected data of ArrayType from a DataFrame is observed to be of
// type "scala.collection.mutable.WrappedArray"
// type "scala.collection.mutable.ArraySeq"
val value = obj match {
-case wa: mutable.WrappedArray[_] => wa.array
+case wa: mutable.ArraySeq[_] => wa.array
case other => other
}

12 changes: 6 additions & 6 deletions core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -29,8 +29,8 @@ import java.util.concurrent.locks.ReentrantLock
import javax.annotation.concurrent.GuardedBy
import javax.ws.rs.core.UriBuilder

-import scala.collection.immutable
-import scala.collection.mutable.{ArrayBuffer, HashMap, WrappedArray}
+import scala.collection.{immutable, mutable}
+import scala.collection.mutable.{ArrayBuffer, HashMap}
import scala.concurrent.duration._
import scala.jdk.CollectionConverters._
import scala.util.control.NonFatal
@@ -741,9 +741,9 @@ private[spark] class Executor(
logInfo(s"Executor killed $taskName, reason: ${t.reason}")

val (accums, accUpdates) = collectAccumulatorsAndResetStatusOnFailure(taskStartTimeNs)
-// Here and below, put task metric peaks in a WrappedArray to expose them as a Seq
+// Here and below, put task metric peaks in an ArraySeq to expose them as a Seq
// without requiring a copy.
-val metricPeaks = WrappedArray.make(metricsPoller.getTaskMetricPeaks(taskId))
+val metricPeaks = mutable.ArraySeq.make(metricsPoller.getTaskMetricPeaks(taskId))
val reason = TaskKilled(t.reason, accUpdates, accums, metricPeaks.toSeq)
plugins.foreach(_.onTaskFailed(reason))
execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(reason))
@@ -754,7 +754,7 @@
logInfo(s"Executor interrupted and killed $taskName, reason: $killReason")

val (accums, accUpdates) = collectAccumulatorsAndResetStatusOnFailure(taskStartTimeNs)
-val metricPeaks = WrappedArray.make(metricsPoller.getTaskMetricPeaks(taskId))
+val metricPeaks = mutable.ArraySeq.make(metricsPoller.getTaskMetricPeaks(taskId))
val reason = TaskKilled(killReason, accUpdates, accums, metricPeaks.toSeq)
plugins.foreach(_.onTaskFailed(reason))
execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(reason))
@@ -798,7 +798,7 @@ private[spark] class Executor(
// instead of an app issue).
if (!ShutdownHookManager.inShutdown()) {
val (accums, accUpdates) = collectAccumulatorsAndResetStatusOnFailure(taskStartTimeNs)
-val metricPeaks = WrappedArray.make(metricsPoller.getTaskMetricPeaks(taskId))
+val metricPeaks = mutable.ArraySeq.make(metricsPoller.getTaskMetricPeaks(taskId))

val (taskFailureReason, serializedTaskFailureReason) = {
try {
@@ -74,7 +74,7 @@ private[r] class LDAWrapper private (
if (vocabulary.isEmpty || vocabulary.length < vocabSize) {
topicIndices
} else {
-val index2term = udf { indices: mutable.WrappedArray[Int] => indices.map(i => vocabulary(i)) }
+val index2term = udf { indices: mutable.ArraySeq[Int] => indices.map(i => vocabulary(i)) }
topicIndices
.select(col("topic"), index2term(col("termIndices")).as("term"), col("termWeights"))
}
@@ -21,7 +21,7 @@ import java.io.File
import java.util.Random

import scala.collection.mutable
-import scala.collection.mutable.{ArrayBuffer, WrappedArray}
+import scala.collection.mutable.ArrayBuffer
import scala.jdk.CollectionConverters._

import org.apache.commons.io.FileUtils
@@ -1012,7 +1012,7 @@ class ALSSuite extends MLTest with DefaultReadWriteTest with Logging {
assert(recs === expected(id))
}
topK.collect().foreach { row =>
-val recs = row.getAs[WrappedArray[Row]]("recommendations")
+val recs = row.getAs[mutable.ArraySeq[Row]]("recommendations")
assert(recs(0).fieldIndex(dstColName) == 0)
assert(recs(0).fieldIndex("rating") == 1)
}
@@ -104,7 +104,7 @@ object RowEncoder {
UDTEncoder(udt, udtClass.asInstanceOf[Class[_ <: UserDefinedType[_]]])
case ArrayType(elementType, containsNull) =>
IterableEncoder(
-classTag[mutable.WrappedArray[_]],
+classTag[mutable.ArraySeq[_]],
encoderForDataType(elementType, lenient),
containsNull,
lenientSerialization = true)
@@ -265,7 +265,7 @@ case class Stack(children: Seq[Expression]) extends Generator {
})

// Create the collection.
-val wrapperClass = classOf[mutable.WrappedArray[_]].getName
+val wrapperClass = classOf[mutable.ArraySeq[_]].getName
ev.copy(code =
code"""
|$code
@@ -32,6 +32,7 @@ import java.time.{Duration, Instant, LocalDate, LocalDateTime, Period, ZoneOffse
import java.util
import java.util.Objects

+import scala.collection.mutable
import scala.math.{BigDecimal, BigInt}
import scala.reflect.runtime.universe.TypeTag
import scala.util.Try
@@ -88,7 +89,7 @@ object Literal {
case d: Duration => Literal(durationToMicros(d), DayTimeIntervalType())
case p: Period => Literal(periodToMonths(p), YearMonthIntervalType())
case a: Array[Byte] => Literal(a, BinaryType)
-case a: collection.mutable.WrappedArray[_] => apply(a.array)
+case a: mutable.ArraySeq[_] => apply(a.array)
case a: Array[_] =>
val elementType = componentTypeToDataType(a.getClass.getComponentType())
val dataType = ArrayType(elementType)
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions.objects
import java.lang.reflect.{Method, Modifier}

import scala.collection.mutable
-import scala.collection.mutable.{Builder, WrappedArray}
+import scala.collection.mutable.Builder
import scala.jdk.CollectionConverters._
import scala.reflect.ClassTag
import scala.util.Try
@@ -915,15 +915,15 @@ case class MapObjects private(
}

private lazy val mapElements: scala.collection.Seq[_] => Any = customCollectionCls match {
-case Some(cls) if classOf[WrappedArray[_]].isAssignableFrom(cls) =>
+case Some(cls) if classOf[mutable.ArraySeq[_]].isAssignableFrom(cls) =>
// The implicit tag is a workaround to deal with a small change in the
// (scala) signature of ArrayBuilder.make between Scala 2.12 and 2.13.
implicit val tag: ClassTag[Any] = elementClassTag()
input => {
val builder = mutable.ArrayBuilder.make[Any]
builder.sizeHint(input.size)
executeFuncOnCollection(input).foreach(builder += _)
-mutable.WrappedArray.make(builder.result())
+mutable.ArraySeq.make(builder.result())
}
case Some(cls) if classOf[scala.collection.Seq[_]].isAssignableFrom(cls) =>
// Scala sequence
@@ -1081,7 +1081,7 @@

val (initCollection, addElement, getResult): (String, String => String, String) =
customCollectionCls match {
-case Some(cls) if classOf[WrappedArray[_]].isAssignableFrom(cls) =>
+case Some(cls) if classOf[mutable.ArraySeq[_]].isAssignableFrom(cls) =>
val tag = ctx.addReferenceObj("tag", elementClassTag())
val builderClassName = classOf[mutable.ArrayBuilder[_]].getName
val getBuilder = s"$builderClassName$$.MODULE$$.make($tag)"
@@ -1092,7 +1092,7 @@
$builder.sizeHint($dataLength);
""",
(genValue: String) => s"$builder.$$plus$$eq($genValue);",
s"(${cls.getName}) ${classOf[WrappedArray[_]].getName}$$." +
s"(${cls.getName}) ${classOf[mutable.ArraySeq[_]].getName}$$." +
s"MODULE$$.make($builder.result());"
)

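Both the interpreted branch (`mapElements`) and the codegen branch above follow the same recipe: accumulate into a pre-sized `ArrayBuilder`, then wrap the result without copying via `ArraySeq.make`. A standalone sketch of that recipe (hypothetical helper name, plain Scala 2.13):

```scala
import scala.collection.mutable
import scala.reflect.ClassTag

// Accumulate mapped elements into a pre-sized ArrayBuilder, then
// wrap the resulting array copy-free with ArraySeq.make.
def mapToArraySeq[A, B: ClassTag](input: Seq[A])(f: A => B): mutable.ArraySeq[B] = {
  val builder = mutable.ArrayBuilder.make[B]
  builder.sizeHint(input.size)
  input.foreach(a => builder += f(a))
  mutable.ArraySeq.make(builder.result())
}

// e.g. mapToArraySeq(Seq(1, 2, 3))(_ * 2) == mutable.ArraySeq(2, 4, 6)
```
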
@@ -313,9 +313,9 @@ class RowEncoderSuite extends CodegenInterpretedPlanTest {

private def roundTripArray[T](dt: DataType, nullable: Boolean, data: Array[T]): Unit = {
val schema = new StructType().add("a", ArrayType(dt, nullable))
test(s"RowEncoder should return WrappedArray with properly typed array for $schema") {
test(s"RowEncoder should return mutable.ArraySeq with properly typed array for $schema") {
val encoder = ExpressionEncoder(schema).resolveAndBind()
-val result = fromRow(encoder, toRow(encoder, Row(data))).getAs[mutable.WrappedArray[_]](0)
+val result = fromRow(encoder, toRow(encoder, Row(data))).getAs[mutable.ArraySeq[_]](0)
assert(result.array.getClass === data.getClass)
assert(result === data)
}
@@ -473,13 +473,13 @@ class RowEncoderSuite extends CodegenInterpretedPlanTest {
}
}

test("Encoding an ArraySeq/WrappedArray in scala-2.13") {
test("Encoding an mutable.ArraySeq in scala-2.13") {
val schema = new StructType()
.add("headers", ArrayType(new StructType()
.add("key", StringType)
.add("value", BinaryType)))
val encoder = ExpressionEncoder(schema, lenient = true).resolveAndBind()
-val data = Row(mutable.WrappedArray.make(Array(Row("key", "value".getBytes))))
+val data = Row(mutable.ArraySeq.make(Array(Row("key", "value".getBytes))))
val row = encoder.createSerializer()(data)
}
}
@@ -22,6 +22,7 @@ import java.time.{Duration, Instant, LocalDate, LocalDateTime, Period, ZoneOffse
import java.time.temporal.ChronoUnit
import java.util.TimeZone

+import scala.collection.mutable
import scala.reflect.runtime.universe.TypeTag

import org.apache.spark.{SparkException, SparkFunSuite}
@@ -207,7 +208,7 @@ class LiteralExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
checkArrayLiteral(Array("a", "b", "c"))
checkArrayLiteral(Array(1.0, 4.0))
checkArrayLiteral(Array(new CalendarInterval(1, 0, 0), new CalendarInterval(0, 1, 0)))
-val arr = collection.mutable.WrappedArray.make(Array(1.0, 4.0))
+val arr = mutable.ArraySeq.make(Array(1.0, 4.0))
checkEvaluation(Literal(arr), toCatalyst(arr))
}

@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions

import java.sql.{Date, Timestamp}

-import scala.collection.mutable.WrappedArray
+import scala.collection.mutable
import scala.jdk.CollectionConverters._
import scala.reflect.ClassTag
import scala.reflect.runtime.universe.TypeTag
@@ -361,16 +361,16 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
assert(result.asInstanceOf[ArrayData].array.toSeq == expected)
case l if classOf[java.util.List[_]].isAssignableFrom(l) =>
assert(result.asInstanceOf[java.util.List[_]].asScala == expected)
-case a if classOf[WrappedArray[Int]].isAssignableFrom(a) =>
-assert(result == WrappedArray.make[Int](expected.toArray))
+case a if classOf[mutable.ArraySeq[Int]].isAssignableFrom(a) =>
+assert(result == mutable.ArraySeq.make[Int](expected.toArray))
case s if classOf[Seq[_]].isAssignableFrom(s) =>
assert(result.asInstanceOf[Seq[_]] == expected)
case s if classOf[scala.collection.Set[_]].isAssignableFrom(s) =>
assert(result.asInstanceOf[scala.collection.Set[_]] == expected.toSet)
}
}

-val customCollectionClasses = Seq(classOf[WrappedArray[Int]],
+val customCollectionClasses = Seq(classOf[mutable.ArraySeq[Int]],
classOf[Seq[Int]], classOf[scala.collection.Set[Int]],
classOf[java.util.List[Int]], classOf[java.util.AbstractList[Int]],
classOf[java.util.AbstractSequentialList[Int]], classOf[java.util.Vector[Int]],
@@ -391,7 +391,7 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
stack.add(3)

Seq(
-(Seq(1, 2, 3), ObjectType(classOf[WrappedArray[Int]])),
+(Seq(1, 2, 3), ObjectType(classOf[mutable.ArraySeq[Int]])),
(Seq(1, 2, 3), ObjectType(classOf[Seq[Int]])),
(Array(1, 2, 3), ObjectType(classOf[Array[Int]])),
(Seq(1, 2, 3), ObjectType(classOf[Object])),
@@ -4542,8 +4542,8 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
.parquet(dir.getCanonicalPath)
checkAnswer(res,
Seq(
-Row(1, false, mutable.WrappedArray.make(binary1)),
-Row(2, true, mutable.WrappedArray.make(binary2))
+Row(1, false, mutable.ArraySeq.make(binary1)),
+Row(2, true, mutable.ArraySeq.make(binary2))
))
}
}
9 changes: 5 additions & 4 deletions sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
@@ -22,7 +22,8 @@ import java.sql.Timestamp
import java.time.{Instant, LocalDate}
import java.time.format.DateTimeFormatter

-import scala.collection.mutable.{ArrayBuffer, WrappedArray}
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer

import org.apache.spark.{SPARK_DOC_ROOT, SparkException}
import org.apache.spark.sql.api.java._
@@ -820,9 +821,9 @@ class UDFSuite extends QueryTest with SharedSparkSession {
checkAnswer(sql("SELECT key(a) AS k FROM t GROUP BY key(a)"), Row(1) :: Nil)
}

test("SPARK-32459: UDF should not fail on WrappedArray") {
val myUdf = udf((a: WrappedArray[Int]) =>
WrappedArray.make[Int](Array(a.head + 99)))
test("SPARK-32459: UDF should not fail on mutable.ArraySeq") {
val myUdf = udf((a: mutable.ArraySeq[Int]) =>
mutable.ArraySeq.make[Int](Array(a.head + 99)))
checkAnswer(Seq(Array(1))
.toDF("col")
.select(myUdf(Column("col"))),
