[SPARK-48463][ML] Make Binarizer, Bucketizer, VectorAssembler, FeatureHasher, QuantileDiscretizer, OneHotEncoder, StopWordsRemover, Imputer, Interaction support nested input columns

### What changes were proposed in this pull request?

Make Binarizer, Bucketizer, VectorAssembler, FeatureHasher, QuantileDiscretizer, OneHotEncoder, StopWordsRemover, Imputer, and Interaction support nested input columns.
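
For illustration, a minimal usage sketch of the new capability (modeled on the nested-input test added in BinarizerSuite below; the local SparkSession setup, sample values, and threshold are arbitrary and not part of this PR):

```scala
import org.apache.spark.ml.feature.Binarizer
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, struct}

val spark = SparkSession.builder().master("local[1]").appName("nested-input-example").getOrCreate()
import spark.implicits._

// Wrap the raw value in a struct column so that "nest.feature" is a nested field.
val df = Seq(0.1, 0.4, 0.9).toDF("feature")
  .select(struct(col("feature")).as("nest"))

// After this change, the input column may reference the nested field directly.
val binarizer = new Binarizer()
  .setInputCol("nest.feature")
  .setOutputCol("binarized_feature")
  .setThreshold(0.5)

binarizer.transform(df).show()
```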

### Why are the changes needed?

These transformers currently accept only top-level input columns. Supporting nested input columns (for example, a field inside a struct such as `nest.feature`) lets them be applied directly to struct-typed data without flattening it first.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Unit tests.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes apache#47719 from WeichenXu123/ML-43641.

Authored-by: Weichen Xu <[email protected]>
Signed-off-by: Weichen Xu <[email protected]>
WeichenXu123 committed Aug 13, 2024
1 parent acfe847 commit e7e0826
Showing 19 changed files with 310 additions and 47 deletions.
14 changes: 12 additions & 2 deletions mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala
@@ -19,6 +19,7 @@ package org.apache.spark.ml.feature

import scala.collection.mutable.ArrayBuilder

import org.apache.spark.{SparkException, SparkIllegalArgumentException}
import org.apache.spark.annotation.Since
import org.apache.spark.internal.{LogKeys, MDC}
import org.apache.spark.ml.Transformer
@@ -117,7 +118,7 @@ final class Binarizer @Since("1.4.0") (@Since("1.4.0") override val uid: String)
}

val mappedOutputCols = inputColNames.zip(tds).map { case (colName, td) =>
dataset.schema(colName).dataType match {
dataset.col(colName).expr.dataType match {
case DoubleType =>
when(!col(colName).isNaN && col(colName) > td, lit(1.0))
.otherwise(lit(0.0))
@@ -199,7 +200,16 @@ final class Binarizer @Since("1.4.0") (@Since("1.4.0") override val uid: String)
inputColNames.zip(outputColNames).foreach { case (inputColName, outputColName) =>
require(!schema.fieldNames.contains(outputColName),
s"Output column $outputColName already exists.")
val inputType = schema(inputColName).dataType

val inputType = try {
SchemaUtils.getSchemaFieldType(schema, inputColName)
} catch {
case e: SparkIllegalArgumentException if e.getErrorClass == "FIELD_NOT_FOUND" =>
throw new SparkException(s"Input column $inputColName does not exist.")
case e: Exception =>
throw e
}

val outputField = inputType match {
case DoubleType =>
BinaryAttribute.defaultAttr.withName(outputColName).toStructField()
mllib/src/main/scala/org/apache/spark/ml/feature/FeatureHasher.scala
@@ -131,8 +131,10 @@ class FeatureHasher(@Since("2.3.0") override val uid: String) extends Transforme
val n = $(numFeatures)
val localInputCols = $(inputCols)

var catCols = dataset.schema(localInputCols.toSet)
.filterNot(_.dataType.isInstanceOf[NumericType]).map(_.name).toArray
var catCols = localInputCols.map {
localInputCol => SchemaUtils.getSchemaField(dataset.schema, localInputCol)
}.filterNot(_.dataType.isInstanceOf[NumericType]).map(_.name)

if (isSet(categoricalCols)) {
// categoricalCols may contain columns not set in inputCols
catCols = (catCols ++ $(categoricalCols).intersect(localInputCols)).distinct
@@ -204,17 +206,17 @@ class FeatureHasher(@Since("2.3.0") override val uid: String) extends Transforme
log.warn(s"categoricalCols ${set.mkString("[", ",", "]")} do not exist in inputCols")
}
}
val fields = schema(localInputCols)
fields.foreach { fieldSchema =>
val dataType = fieldSchema.dataType
val fieldName = fieldSchema.name
for (fieldName <- localInputCols) {
val field = SchemaUtils.getSchemaField(schema, fieldName)
val dataType = field.dataType
require(dataType.isInstanceOf[NumericType] ||
dataType.isInstanceOf[StringType] ||
dataType.isInstanceOf[BooleanType],
s"FeatureHasher requires columns to be of ${NumericType.simpleString}, " +
s"${BooleanType.catalogString} or ${StringType.catalogString}. " +
s"Column $fieldName was ${dataType.catalogString}")
}

val attrGroup = new AttributeGroup($(outputCol), $(numFeatures))
SchemaUtils.appendColumn(schema, attrGroup.toStructField())
}
12 changes: 7 additions & 5 deletions mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala
@@ -90,7 +90,7 @@ private[feature] trait ImputerParams extends Params with HasInputCol with HasInp
require(inputColNames.length == outputColNames.length, s"inputCols(${inputColNames.length})" +
s" and outputCols(${outputColNames.length}) should have the same length")
val outputFields = inputColNames.zip(outputColNames).map { case (inputCol, outputCol) =>
val inputField = schema(inputCol)
val inputField = SchemaUtils.getSchemaField(schema, inputCol)
SchemaUtils.checkNumericType(schema, inputCol)
StructField(outputCol, inputField.dataType, inputField.nullable)
}
@@ -155,12 +155,14 @@ class Imputer @Since("2.2.0") (@Since("2.2.0") override val uid: String)
val spark = dataset.sparkSession

val (inputColumns, _) = getInOutCols()
val cols = inputColumns.map { inputCol =>

val transformedColNames = Array.tabulate(inputColumns.length)(index => s"c_$index")
val cols = inputColumns.zip(transformedColNames).map { case (inputCol, transformedColName) =>
when(col(inputCol).equalTo($(missingValue)), null)
.when(col(inputCol).isNaN, null)
.otherwise(col(inputCol))
.cast(DoubleType)
.as(inputCol)
.as(transformedColName)
}
val numCols = cols.length

Expand All @@ -176,7 +178,7 @@ class Imputer @Since("2.2.0") (@Since("2.2.0") override val uid: String)
// Function approxQuantile will ignore null automatically.
// For a column only containing null, approxQuantile will return an empty array.
dataset.select(cols.toImmutableArraySeq: _*)
.stat.approxQuantile(inputColumns, Array(0.5), $(relativeError))
.stat.approxQuantile(transformedColNames, Array(0.5), $(relativeError))
.map(_.headOption.getOrElse(Double.NaN))

case Imputer.mode =>
@@ -271,7 +273,7 @@ class ImputerModel private[ml] (

val newCols = inputColumns.map { inputCol =>
val surrogate = surrogates(inputCol)
val inputType = dataset.schema(inputCol).dataType
val inputType = SchemaUtils.getSchemaFieldType(dataset.schema, inputCol)
val ic = col(inputCol).cast(DoubleType)
when(ic.isNull, surrogate)
.when(ic === $(missingValue), surrogate)
mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala
@@ -70,7 +70,7 @@ class Interaction @Since("1.6.0") (@Since("1.6.0") override val uid: String) ext
@Since("2.0.0")
override def transform(dataset: Dataset[_]): DataFrame = {
transformSchema(dataset.schema, logging = true)
val inputFeatures = $(inputCols).map(c => dataset.schema(c))
val inputFeatures = $(inputCols).map(c => SchemaUtils.getSchemaField(dataset.schema, c))
val featureEncoders = getFeatureEncoders(inputFeatures.toImmutableArraySeq)
val featureAttrs = getFeatureAttrs(inputFeatures.toImmutableArraySeq)

@@ -102,11 +102,11 @@ class Interaction @Since("1.6.0") (@Since("1.6.0") override val uid: String) ext
Vectors.sparse(size, indices.result(), values.result()).compressed
}

val featureCols = inputFeatures.map { f =>
val featureCols = inputFeatures.zip($(inputCols)).map { case (f, inputCol) =>
f.dataType match {
case DoubleType => dataset(f.name)
case _: VectorUDT => dataset(f.name)
case _: NumericType | BooleanType => dataset(f.name).cast(DoubleType)
case DoubleType => dataset(inputCol)
case _: VectorUDT => dataset(inputCol)
case _: NumericType | BooleanType => dataset(inputCol).cast(DoubleType)
}
}
import org.apache.spark.util.ArrayImplicits._
mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala
@@ -89,10 +89,12 @@ private[ml] trait OneHotEncoderBase extends Params with HasHandleInvalid
s"output columns ${outputColNames.length}.")

// Input columns must be NumericType.
inputColNames.foreach(SchemaUtils.checkNumericType(schema, _))
inputColNames.foreach { colName =>
SchemaUtils.checkNumericType(schema, colName)
}

// Prepares output columns with proper attributes by examining input columns.
val inputFields = inputColNames.map(schema(_))
val inputFields = inputColNames.map(SchemaUtils.getSchemaField(schema, _))

val outputFields = inputFields.zip(outputColNames).map { case (inputField, outputColName) =>
OneHotEncoderCommon.transformOutputColumnSchema(
mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
@@ -25,6 +25,7 @@ import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util._
import org.apache.spark.sql.Dataset
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.ArrayImplicits._

@@ -186,6 +187,7 @@ final class QuantileDiscretizer @Since("1.6.0") (@Since("1.6.0") override val ui
}

var outputFields = schema.fields

inputColNames.zip(outputColNames).foreach { case (inputColName, outputColName) =>
SchemaUtils.checkNumericType(schema, inputColName)
require(!schema.fieldNames.contains(outputColName),
@@ -201,13 +203,18 @@ final class QuantileDiscretizer @Since("1.6.0") (@Since("1.6.0") override val ui
transformSchema(dataset.schema, logging = true)
val bucketizer = new Bucketizer(uid).setHandleInvalid($(handleInvalid))
if (isSet(inputCols)) {
val quantileColNames = Array.tabulate($(inputCols).length)(index => s"c_$index")
val quantileDataset = dataset.select($(inputCols).zipWithIndex.map {
case (colName, index) => col(colName).alias(quantileColNames(index))
}.toImmutableArraySeq: _*)

val splitsArray = if (isSet(numBucketsArray)) {
val probArrayPerCol = $(numBucketsArray).map { numOfBuckets =>
(0 to numOfBuckets).map(_.toDouble / numOfBuckets).toArray
}

val probabilityArray = probArrayPerCol.flatten.sorted.distinct
val splitsArrayRaw = dataset.stat.approxQuantile($(inputCols),
val splitsArrayRaw = quantileDataset.stat.approxQuantile(quantileColNames,
probabilityArray, $(relativeError))

splitsArrayRaw.zip(probArrayPerCol).map { case (splits, probs) =>
Expand All @@ -222,12 +229,13 @@ final class QuantileDiscretizer @Since("1.6.0") (@Since("1.6.0") override val ui
}
}
} else {
dataset.stat.approxQuantile($(inputCols),
quantileDataset.stat.approxQuantile(quantileColNames,
(0 to $(numBuckets)).map(_.toDouble / $(numBuckets)).toArray, $(relativeError))
}
bucketizer.setSplitsArray(splitsArray.map(getDistinctSplits))
} else {
val splits = dataset.stat.approxQuantile($(inputCol),
val quantileDataset = dataset.select(col($(inputCol)).alias("c_0"))
val splits = quantileDataset.stat.approxQuantile("c_0",
(0 to $(numBuckets)).map(_.toDouble / $(numBuckets)).toArray, $(relativeError))
bucketizer.setSplits(getDistinctSplits(splits))
}
mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala
@@ -193,13 +193,16 @@ class StopWordsRemover @Since("1.5.0") (@Since("1.5.0") override val uid: String
}

val (inputColNames, outputColNames) = getInOutCols()

val newCols = inputColNames.zip(outputColNames).map { case (inputColName, outputColName) =>
require(!schema.fieldNames.contains(outputColName),
s"Output Column $outputColName already exists.")
val inputType = schema(inputColName).dataType
val inputType = SchemaUtils.getSchemaFieldType(schema, inputColName)
require(DataTypeUtils.sameType(inputType, ArrayType(StringType)), "Input type must be " +
s"${ArrayType(StringType).catalogString} but got ${inputType.catalogString}.")
StructField(outputColName, inputType, schema(inputColName).nullable)
StructField(
outputColName, inputType, SchemaUtils.getSchemaField(schema, inputColName).nullable
)
}
StructType(schema.fields ++ newCols)
}
mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
@@ -17,19 +17,17 @@

package org.apache.spark.ml.feature

import java.util.ArrayList

import org.apache.hadoop.fs.Path

import org.apache.spark.SparkException
import org.apache.spark.{SparkException, SparkIllegalArgumentException}
import org.apache.spark.annotation.Since
import org.apache.spark.internal.{LogKeys, MDC}
import org.apache.spark.ml.{Estimator, Model, Transformer}
import org.apache.spark.ml.attribute.{Attribute, NominalAttribute}
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util._
import org.apache.spark.sql.{AnalysisException, Column, DataFrame, Dataset, Encoder, Encoders, Row, SparkSession}
import org.apache.spark.sql.{AnalysisException, Column, DataFrame, Dataset, Encoder, Encoders, Row}
import org.apache.spark.sql.catalyst.expressions.{If, Literal}
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.functions._
@@ -124,17 +122,15 @@ private[feature] trait StringIndexerBase extends Params with HasHandleInvalid wi
require(outputColNames.distinct.length == outputColNames.length,
s"Output columns should not be duplicate.")

val sparkSession = SparkSession.getActiveSession.get
val transformDataset = sparkSession.createDataFrame(new ArrayList[Row](), schema = schema)
val outputFields = inputColNames.zip(outputColNames).flatMap {
case (inputColName, outputColName) =>
try {
val dtype = transformDataset.col(inputColName).expr.dataType
val dtype = SchemaUtils.getSchemaFieldType(schema, inputColName)
Some(
validateAndTransformField(schema, inputColName, dtype, outputColName)
)
} catch {
case _: AnalysisException =>
case e: SparkIllegalArgumentException if e.getErrorClass == "FIELD_NOT_FOUND" =>
if (skipNonExistsCol) {
None
} else {
mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
@@ -88,7 +88,7 @@ class VectorAssembler @Since("1.4.0") (@Since("1.4.0") override val uid: String)
val schema = dataset.schema

val vectorCols = $(inputCols).filter { c =>
schema(c).dataType match {
dataset.col(c).expr.dataType match {
case _: VectorUDT => true
case _ => false
}
Expand All @@ -97,7 +97,7 @@ class VectorAssembler @Since("1.4.0") (@Since("1.4.0") override val uid: String)
dataset, vectorCols.toImmutableArraySeq, $(handleInvalid))

val featureAttributesMap = $(inputCols).map { c =>
val field = schema(c)
val field = SchemaUtils.getSchemaField(schema, c)
field.dataType match {
case DoubleType =>
val attribute = Attribute.fromStructField(field)
@@ -145,7 +145,7 @@ class VectorAssembler @Since("1.4.0") (@Since("1.4.0") override val uid: String)
VectorAssembler.assemble(lengths, keepInvalid)(r.toSeq: _*)
}.asNondeterministic()
val args = $(inputCols).map { c =>
schema(c).dataType match {
dataset(c).expr.dataType match {
case DoubleType => dataset(c)
case _: VectorUDT => dataset(c)
case _: NumericType | BooleanType => dataset(c).cast(DoubleType).as(s"${c}_double_$uid")
Expand All @@ -161,7 +161,7 @@ class VectorAssembler @Since("1.4.0") (@Since("1.4.0") override val uid: String)
val inputColNames = $(inputCols)
val outputColName = $(outputCol)
val incorrectColumns = inputColNames.flatMap { name =>
schema(name).dataType match {
SchemaUtils.getSchemaFieldType(schema, name) match {
case _: NumericType | BooleanType => None
case t if t.isInstanceOf[VectorUDT] => None
case other => Some(s"Data type ${other.catalogString} of column $name is not supported.")
@@ -226,7 +226,8 @@ object VectorAssembler extends DefaultParamsReadable[VectorAssembler] {
columns: Seq[String],
handleInvalid: String): Map[String, Int] = {
val groupSizes = columns.map { c =>
c -> AttributeGroup.fromStructField(dataset.schema(c)).size
val field = SchemaUtils.getSchemaField(dataset.schema, c)
c -> AttributeGroup.fromStructField(field).size
}.toMap
val missingColumns = groupSizes.filter(_._2 == -1).keys.toSeq
val firstSizes = (missingColumns.nonEmpty, handleInvalid) match {
26 changes: 25 additions & 1 deletion mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala
@@ -19,6 +19,7 @@ package org.apache.spark.ml.util

import org.apache.spark.ml.attribute._
import org.apache.spark.ml.linalg.VectorUDT
import org.apache.spark.sql.catalyst.util.AttributeNameParser
import org.apache.spark.sql.types._


@@ -72,7 +73,7 @@ private[spark] object SchemaUtils {
schema: StructType,
colName: String,
msg: String = ""): Unit = {
val actualDataType = schema(colName).dataType
val actualDataType = getSchemaFieldType(schema, colName)
val message = if (msg != null && msg.trim.length > 0) " " + msg else ""
require(actualDataType.isInstanceOf[NumericType],
s"Column $colName must be of type ${NumericType.simpleString} but was actually of type " +
@@ -204,4 +205,27 @@ private[spark] object SchemaUtils {
new ArrayType(FloatType, false))
checkColumnTypes(schema, colName, typeCandidates)
}

/**
* Get schema field.
* @param schema input schema
* @param colName column name, nested column name is supported.
*/
def getSchemaField(schema: StructType, colName: String): StructField = {
val colSplits = AttributeNameParser.parseAttributeName(colName)
var field = schema(colSplits(0))
for (colSplit <- colSplits.slice(1, colSplits.length)) {
field = field.dataType.asInstanceOf[StructType](colSplit)
}
field
}

/**
* Get schema field type.
* @param schema input schema
* @param colName column name, nested column name is supported.
*/
def getSchemaFieldType(schema: StructType, colName: String): DataType = {
getSchemaField(schema, colName).dataType
}
}
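
For context, a rough standalone sketch of what the new SchemaUtils.getSchemaField helper does (illustrative only, not part of the diff; the committed code uses AttributeNameParser.parseAttributeName, which also handles backtick-quoted names, while this sketch simply splits on dots):

```scala
import org.apache.spark.sql.types._

// Hypothetical re-implementation of the nested-field lookup, for illustration only.
def resolveNestedField(schema: StructType, colName: String): StructField = {
  val parts = colName.split('.')                        // real helper: AttributeNameParser.parseAttributeName
  var field: StructField = schema(parts.head)           // resolve the top-level field
  parts.tail.foreach { p =>
    field = field.dataType.asInstanceOf[StructType](p)  // walk into each nested struct
  }
  field
}

val schema = StructType(Seq(
  StructField("nest", StructType(Seq(StructField("feature", DoubleType)))),
  StructField("expected", DoubleType)))

assert(resolveNestedField(schema, "nest.feature").dataType == DoubleType)
```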
mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala
@@ -21,6 +21,7 @@ import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest}
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.functions.{col, struct}
import org.apache.spark.util.ArrayImplicits._

class BinarizerSuite extends MLTest with DefaultReadWriteTest {
@@ -250,4 +251,20 @@ class BinarizerSuite extends MLTest with DefaultReadWriteTest {
binarizer.transform(df).count()
}
}

test("Binarize nested input") {
val defaultBinarized: Array[Double] = data.map(x => if (x > 0.0) 1.0 else 0.0)
val dataFrame: DataFrame = data.zip(defaultBinarized).toSeq.toDF("feature", "expected")
.select(struct(col("feature")).as("nest"), col("expected"))

val binarizer: Binarizer = new Binarizer()
.setInputCol("nest.feature")
.setOutputCol("binarized_feature")

val resultDF = binarizer.transform(dataFrame)
resultDF.select("binarized_feature", "expected").collect().foreach {
case Row(x: Double, y: Double) =>
assert(x === y, "The feature value is not correct after binarization.")
}
}
}