[SPARK-50267][ML] Improve TargetEncoder.fit with DataFrame APIs
### What changes were proposed in this pull request?
Improve `TargetEncoder.fit` to be based on DataFrame APIs

### Why are the changes needed?
1. It simplifies the implementation.
2. With the DataFrame APIs, the computation benefits from Spark SQL's optimizations (see the sketch below).
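
For illustration only (none of this code is from the patch): the per-category statistics that previously required a handwritten `treeAggregate` over an `RDD[Row]` can be expressed as a declarative `groupBy`/`agg`, letting Catalyst plan the shuffle and apply partial aggregation. A minimal, self-contained sketch:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._

object GroupByAggSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("sketch").getOrCreate()
    import spark.implicits._

    // Toy data: one categorical feature and a binary label.
    val df = Seq((0.0, 1.0), (0.0, 0.0), (1.0, 1.0)).toDF("category", "label")

    // Per-category row count and positive count in a single optimized pass,
    // instead of maintaining Map[Double, (Double, Double)] inside a closure.
    df.groupBy("category")
      .agg(count(lit(1)).as("count"), count_if($"label" === 1).as("positives"))
      .show()

    spark.stop()
  }
}
```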

### Does this PR introduce _any_ user-facing change?
no

### How was this patch tested?
CI

### Was this patch authored or co-authored using generative AI tooling?
no

Closes #48797 from zhengruifeng/target_encoder_fit.

Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
zhengruifeng authored and HyukjinKwon committed Nov 10, 2024
1 parent 8f5d8d4 commit 2cdbede
Showing 2 changed files with 57 additions and 78 deletions.
mllib/src/main/scala/org/apache/spark/ml/feature/TargetEncoder.scala (53 additions, 69 deletions)
```diff
@@ -189,82 +189,66 @@ class TargetEncoder @Since("4.0.0") (@Since("4.0.0") override val uid: String)
     validateSchema(schema, fitting = true)
   }
 
+  private def extractLabel(name: String, targetType: String): Column = {
+    val c = col(name).cast(DoubleType)
+    targetType match {
+      case TargetEncoder.TARGET_BINARY =>
+        when(c === 0 || c === 1, c)
+          .when(c.isNull || c.isNaN, c)
+          .otherwise(raise_error(
+            concat(lit("Labels for TARGET_BINARY must be {0, 1}, but got "), c)))
+
+      case TargetEncoder.TARGET_CONTINUOUS => c
+    }
+  }
+
+  private def extractValue(name: String): Column = {
+    val c = col(name).cast(DoubleType)
+    when(c >= 0 && c === c.cast(IntegerType), c)
+      .when(c.isNull, lit(TargetEncoder.NULL_CATEGORY))
+      .when(c.isNaN, raise_error(lit("Values MUST NOT be NaN")))
+      .otherwise(raise_error(
+        concat(lit("Values MUST be non-negative integers, but got "), c)))
+  }
+
   @Since("4.0.0")
   override def fit(dataset: Dataset[_]): TargetEncoderModel = {
     validateSchema(dataset.schema, fitting = true)
+    val numFeatures = inputFeatures.length
 
-    // stats: Array[Map[category, (counter,stat)]]
-    val stats = dataset
-      .select((inputFeatures :+ $(labelCol)).map(col(_).cast(DoubleType)).toIndexedSeq: _*)
-      .rdd.treeAggregate(
-        Array.fill(inputFeatures.length) {
-          Map.empty[Double, (Double, Double)]
-        })(
-
-        (agg, row: Row) => if (!row.isNullAt(inputFeatures.length)) {
-          val label = row.getDouble(inputFeatures.length)
-          if (!label.equals(Double.NaN)) {
-            inputFeatures.indices.map {
-              feature => {
-                val category: Double = {
-                  if (row.isNullAt(feature)) TargetEncoder.NULL_CATEGORY // null category
-                  else {
-                    val value = row.getDouble(feature)
-                    if (value < 0.0 || value != value.toInt) throw new SparkException(
-                      s"Values from column ${inputFeatures(feature)} must be indices, " +
-                        s"but got $value.")
-                    else value // non-null category
-                  }
-                }
-                val (class_count, class_stat) = agg(feature).getOrElse(category, (0.0, 0.0))
-                val (global_count, global_stat) =
-                  agg(feature).getOrElse(TargetEncoder.UNSEEN_CATEGORY, (0.0, 0.0))
-                $(targetType) match {
-                  case TargetEncoder.TARGET_BINARY => // counting
-                    if (label == 1.0) {
-                      // positive => increment both counters for current & unseen categories
-                      agg(feature) +
-                        (category -> (1 + class_count, 1 + class_stat)) +
-                        (TargetEncoder.UNSEEN_CATEGORY -> (1 + global_count, 1 + global_stat))
-                    } else if (label == 0.0) {
-                      // negative => increment only global counter for current & unseen categories
-                      agg(feature) +
-                        (category -> (1 + class_count, class_stat)) +
-                        (TargetEncoder.UNSEEN_CATEGORY -> (1 + global_count, global_stat))
-                    } else throw new SparkException(
-                      s"Values from column ${getLabelCol} must be binary (0,1) but got $label.")
-                  case TargetEncoder.TARGET_CONTINUOUS => // incremental mean
-                    // increment counter and iterate on mean for current & unseen categories
-                    agg(feature) +
-                      (category -> (1 + class_count,
-                        class_stat + ((label - class_stat) / (1 + class_count)))) +
-                      (TargetEncoder.UNSEEN_CATEGORY -> (1 + global_count,
-                        global_stat + ((label - global_stat) / (1 + global_count))))
-                }
-              }
-            }.toArray
-          } else agg // ignore NaN-labeled observations
-        } else agg, // ignore null-labeled observations
-
-        (agg1, agg2) => inputFeatures.indices.map {
-          feature => {
-            val categories = agg1(feature).keySet ++ agg2(feature).keySet
-            categories.map(category =>
-              category -> {
-                val (counter1, stat1) = agg1(feature).getOrElse(category, (0.0, 0.0))
-                val (counter2, stat2) = agg2(feature).getOrElse(category, (0.0, 0.0))
-                $(targetType) match {
-                  case TargetEncoder.TARGET_BINARY => (counter1 + counter2, stat1 + stat2)
-                  case TargetEncoder.TARGET_CONTINUOUS => (counter1 + counter2,
-                    ((counter1 * stat1) + (counter2 * stat2)) / (counter1 + counter2))
-                }
-              }).toMap
-          }
-        }.toArray)
+    // Append the unseen category, for global stats computation
+    val arrayCol = array(
+      (inputFeatures.map(v => extractValue(v)) :+ lit(TargetEncoder.UNSEEN_CATEGORY))
+        .toIndexedSeq: _*)
+
+    val checked = dataset
+      .select(extractLabel($(labelCol), $(targetType)).as("label"), arrayCol.as("array"))
+      .where(!col("label").isNaN && !col("label").isNull)
+      .select(col("label"), posexplode(col("array")).as(Seq("index", "value")))
+
+    val statCol = $(targetType) match {
+      case TargetEncoder.TARGET_BINARY => count_if(col("label") === 1)
+      case TargetEncoder.TARGET_CONTINUOUS => avg(col("label"))
+    }
+    val aggregated = checked
+      .groupBy("index", "value")
+      .agg(count(lit(1)).cast(DoubleType).as("count"), statCol.cast(DoubleType).as("stat"))
+
+    // stats: Array[Map[category, (counter,stat)]]
+    val stats = Array.fill(numFeatures)(collection.mutable.Map.empty[Double, (Double, Double)])
+    aggregated.select("index", "value", "count", "stat").collect()
+      .foreach { case Row(index: Int, value: Double, count: Double, stat: Double) =>
+        if (index < numFeatures) {
+          // Assign the per-category stats to the corresponding feature
+          stats(index).update(value, (count, stat))
+        } else {
+          // Assign the global stats to all features
+          assert(value == TargetEncoder.UNSEEN_CATEGORY)
+          stats.foreach { s => s.update(value, (count, stat)) }
+        }
+      }
 
-    val model = new TargetEncoderModel(uid, stats).setParent(this)
+    val model = new TargetEncoderModel(uid, stats.map(_.toMap)).setParent(this)
     copyValues(model)
   }
```
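
To make the new control flow concrete, here is a self-contained sketch of the same pipeline on toy data. The column names and the `UNSEEN` sentinel are illustrative stand-ins, not the actual constants used by `TargetEncoder`:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._

object FitPipelineSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("sketch").getOrCreate()
    import spark.implicits._

    val UNSEEN = -1.0 // stand-in for TargetEncoder.UNSEEN_CATEGORY

    // Two feature columns plus a binary label.
    val df = Seq((0.0, 2.0, 1.0), (0.0, 3.0, 0.0), (1.0, 2.0, 1.0))
      .toDF("input1", "input2", "label")

    // One array per row: each feature's category plus the UNSEEN sentinel, so
    // the global stats fall out of the same aggregation as per-feature stats.
    val arrayCol = array($"input1", $"input2", lit(UNSEEN))

    // posexplode emits one (feature index, category) pair per array element...
    val checked = df.select($"label", posexplode(arrayCol).as(Seq("index", "value")))

    // ...so a single groupBy covers every feature at once; for a binary target
    // the stat is the positive count, for a continuous target it would be avg.
    checked.groupBy("index", "value")
      .agg(count(lit(1)).as("count"), count_if($"label" === 1).as("stat"))
      .orderBy("index", "value")
      .show()
    // Rows with index == 2 (the sentinel slot) carry the global stats that
    // fit() copies to TargetEncoder.UNSEEN_CATEGORY for every feature.

    spark.stop()
  }
}
```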

```diff
@@ -376,15 +376,13 @@ class TargetEncoderSuite extends MLTest with DefaultReadWriteTest {
     val df_noindex = spark
       .createDataFrame(sc.parallelize(data_binary :+ data_noindex), schema)
 
-    val ex = intercept[SparkException] {
+    val ex = intercept[SparkRuntimeException] {
       val model = encoder.fit(df_noindex)
       print(model.stats)
     }
 
-    assert(ex.isInstanceOf[SparkException])
     assert(ex.getMessage.contains(
-      "Values from column input3 must be indices, but got 5.1"))
-
+      "Values MUST be non-negative integers, but got 5.1"))
   }
 
   test("TargetEncoder - invalid label") {
@@ -407,7 +405,6 @@ class TargetEncoderSuite extends MLTest with DefaultReadWriteTest {
     model.stats.zip(expected_stats_continuous).foreach{
       case (actual, expected) => assert(actual.equals(expected))
     }
-
   }
 
   test("TargetEncoder - non-binary labels") {
@@ -423,15 +420,13 @@ class TargetEncoderSuite extends MLTest with DefaultReadWriteTest {
     val df_non_binary = spark
       .createDataFrame(sc.parallelize(data_binary :+ data_non_binary), schema)
 
-    val ex = intercept[SparkException] {
+    val ex = intercept[SparkRuntimeException] {
       val model = encoder.fit(df_non_binary)
       print(model.stats)
     }
 
-    assert(ex.isInstanceOf[SparkException])
     assert(ex.getMessage.contains(
-      "Values from column label must be binary (0,1) but got 2.0"))
-
+      "Labels for TARGET_BINARY must be {0, 1}, but got 2.0"))
   }
 
   test("TargetEncoder - features renamed") {
```
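
Why the expected exception type changes in these tests: validation now happens inside column expressions via `raise_error`, which surfaces as a `SparkRuntimeException` when the plan is executed, rather than a `SparkException` thrown from an RDD closure. A minimal sketch, assuming an existing local `SparkSession` named `spark` and ScalaTest's `intercept` (as in the suite above):

```scala
import org.apache.spark.SparkRuntimeException
import org.apache.spark.sql.functions._

// raise_error only fires when the plan actually runs, so the fit-time checks
// are part of query execution rather than driver-side validation code.
val ex = intercept[SparkRuntimeException] {
  spark.range(1).select(raise_error(lit("Values MUST NOT be NaN"))).collect()
}
assert(ex.getMessage.contains("Values MUST NOT be NaN"))
```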
