diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/TargetEncoder.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/TargetEncoder.scala
index 9afb88afec932..0103282c269dd 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/TargetEncoder.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/TargetEncoder.scala
@@ -189,82 +189,66 @@ class TargetEncoder @Since("4.0.0") (@Since("4.0.0") override val uid: String)
     validateSchema(schema, fitting = true)
   }
 
+  private def extractLabel(name: String, targetType: String): Column = {
+    val c = col(name).cast(DoubleType)
+    targetType match {
+      case TargetEncoder.TARGET_BINARY =>
+        when(c === 0 || c === 1, c)
+          .when(c.isNull || c.isNaN, c)
+          .otherwise(raise_error(
+            concat(lit("Labels for TARGET_BINARY must be {0, 1}, but got "), c)))
+
+      case TargetEncoder.TARGET_CONTINUOUS => c
+    }
+  }
+
+  private def extractValue(name: String): Column = {
+    val c = col(name).cast(DoubleType)
+    when(c >= 0 && c === c.cast(IntegerType), c)
+      .when(c.isNull, lit(TargetEncoder.NULL_CATEGORY))
+      .when(c.isNaN, raise_error(lit("Values MUST NOT be NaN")))
+      .otherwise(raise_error(
+        concat(lit("Values MUST be non-negative integers, but got "), c)))
+  }
+
   @Since("4.0.0")
   override def fit(dataset: Dataset[_]): TargetEncoderModel = {
     validateSchema(dataset.schema, fitting = true)
+    val numFeatures = inputFeatures.length
 
-    // stats: Array[Map[category, (counter,stat)]]
-    val stats = dataset
-      .select((inputFeatures :+ $(labelCol)).map(col(_).cast(DoubleType)).toIndexedSeq: _*)
-      .rdd.treeAggregate(
-        Array.fill(inputFeatures.length) {
-          Map.empty[Double, (Double, Double)]
-        })(
-
-        (agg, row: Row) => if (!row.isNullAt(inputFeatures.length)) {
-          val label = row.getDouble(inputFeatures.length)
-          if (!label.equals(Double.NaN)) {
-            inputFeatures.indices.map {
-              feature => {
-                val category: Double = {
-                  if (row.isNullAt(feature)) TargetEncoder.NULL_CATEGORY // null category
-                  else {
-                    val value = row.getDouble(feature)
-                    if (value < 0.0 || value != value.toInt) throw new SparkException(
-                      s"Values from column ${inputFeatures(feature)} must be indices, " +
-                        s"but got $value.")
-                    else value // non-null category
-                  }
-                }
-                val (class_count, class_stat) = agg(feature).getOrElse(category, (0.0, 0.0))
-                val (global_count, global_stat) =
-                  agg(feature).getOrElse(TargetEncoder.UNSEEN_CATEGORY, (0.0, 0.0))
-                $(targetType) match {
-                  case TargetEncoder.TARGET_BINARY => // counting
-                    if (label == 1.0) {
-                      // positive => increment both counters for current & unseen categories
-                      agg(feature) +
-                        (category -> (1 + class_count, 1 + class_stat)) +
-                        (TargetEncoder.UNSEEN_CATEGORY -> (1 + global_count, 1 + global_stat))
-                    } else if (label == 0.0) {
-                      // negative => increment only global counter for current & unseen categories
-                      agg(feature) +
-                        (category -> (1 + class_count, class_stat)) +
-                        (TargetEncoder.UNSEEN_CATEGORY -> (1 + global_count, global_stat))
-                    } else throw new SparkException(
-                      s"Values from column ${getLabelCol} must be binary (0,1) but got $label.")
-                  case TargetEncoder.TARGET_CONTINUOUS => // incremental mean
-                    // increment counter and iterate on mean for current & unseen categories
-                    agg(feature) +
-                      (category -> (1 + class_count,
-                        class_stat + ((label - class_stat) / (1 + class_count)))) +
-                      (TargetEncoder.UNSEEN_CATEGORY -> (1 + global_count,
-                        global_stat + ((label - global_stat) / (1 + global_count))))
-                }
-              }
-            }.toArray
-          } else agg // ignore NaN-labeled observations
-        } else agg, // ignore null-labeled observations
-
-        (agg1, agg2) =>
-          inputFeatures.indices.map {
-            feature => {
-              val categories = agg1(feature).keySet ++ agg2(feature).keySet
-              categories.map(category =>
-                category -> {
-                  val (counter1, stat1) = agg1(feature).getOrElse(category, (0.0, 0.0))
-                  val (counter2, stat2) = agg2(feature).getOrElse(category, (0.0, 0.0))
-                  $(targetType) match {
-                    case TargetEncoder.TARGET_BINARY => (counter1 + counter2, stat1 + stat2)
-                    case TargetEncoder.TARGET_CONTINUOUS => (counter1 + counter2,
-                      ((counter1 * stat1) + (counter2 * stat2)) / (counter1 + counter2))
-                  }
-                }).toMap
-            }
-          }.toArray)
+    // Append the unseen category, for global stats computation
+    val arrayCol = array(
+      (inputFeatures.map(v => extractValue(v)) :+ lit(TargetEncoder.UNSEEN_CATEGORY))
+        .toIndexedSeq: _*)
+
+    val checked = dataset
+      .select(extractLabel($(labelCol), $(targetType)).as("label"), arrayCol.as("array"))
+      .where(!col("label").isNaN && !col("label").isNull)
+      .select(col("label"), posexplode(col("array")).as(Seq("index", "value")))
+    val statCol = $(targetType) match {
+      case TargetEncoder.TARGET_BINARY => count_if(col("label") === 1)
+      case TargetEncoder.TARGET_CONTINUOUS => avg(col("label"))
+    }
+    val aggregated = checked
+      .groupBy("index", "value")
+      .agg(count(lit(1)).cast(DoubleType).as("count"), statCol.cast(DoubleType).as("stat"))
+    // stats: Array[Map[category, (counter,stat)]]
+    val stats = Array.fill(numFeatures)(collection.mutable.Map.empty[Double, (Double, Double)])
+    aggregated.select("index", "value", "count", "stat").collect()
+      .foreach { case Row(index: Int, value: Double, count: Double, stat: Double) =>
+        if (index < numFeatures) {
+          // Assign the per-category stats to the corresponding feature
+          stats(index).update(value, (count, stat))
+        } else {
+          // Assign the global stats to all features
+          assert(value == TargetEncoder.UNSEEN_CATEGORY)
+          stats.foreach { s => s.update(value, (count, stat)) }
+        }
+      }
 
-    val model = new TargetEncoderModel(uid, stats).setParent(this)
+    val model = new TargetEncoderModel(uid, stats.map(_.toMap)).setParent(this)
     copyValues(model)
   }
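
A minimal, standalone sketch of the aggregation pattern the new `fit()` builds on: pack every feature value into an array alongside a sentinel slot for global stats, `posexplode` the array into `(index, value)` pairs, and compute per-category counts and target stats in a single `groupBy`. The session setup, column names, and sentinel value below are illustrative stand-ins, not part of the patch.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.DoubleType

object FitAggregationSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("sketch").getOrCreate()
    import spark.implicits._

    val UNSEEN = -1.0 // stand-in for TargetEncoder.UNSEEN_CATEGORY
    val df = Seq(
      (0.0, 0.0, 1.0),
      (1.0, 0.0, 1.0),
      (1.0, 1.0, 0.0)
    ).toDF("f1", "f2", "label")

    // One array slot per feature, plus the sentinel used for global stats.
    val arrayCol = array(
      (Seq("f1", "f2").map(c => col(c).cast(DoubleType)) :+ lit(UNSEEN)): _*)

    // "index" identifies the feature (or the sentinel slot), "value" the category.
    val aggregated = df
      .select(col("label"), posexplode(arrayCol).as(Seq("index", "value")))
      .groupBy("index", "value")
      .agg(count(lit(1)).as("count"), count_if(col("label") === 1).as("stat"))

    aggregated.orderBy("index", "value").show()
    spark.stop()
  }
}
```

Because every `(feature, category)` pair becomes one group, a single shuffle replaces the removed per-partition `Map` merging, and the sentinel rows yield the global `(count, stat)` that the patch assigns to `UNSEEN_CATEGORY` for every feature.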
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/TargetEncoderSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/TargetEncoderSuite.scala
index 869be94ff1273..6bb3ce224a2e7 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/TargetEncoderSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/TargetEncoderSuite.scala
@@ -376,15 +376,13 @@ class TargetEncoderSuite extends MLTest with DefaultReadWriteTest {
     val df_noindex = spark
       .createDataFrame(sc.parallelize(data_binary :+ data_noindex), schema)
 
-    val ex = intercept[SparkException] {
+    val ex = intercept[SparkRuntimeException] {
       val model = encoder.fit(df_noindex)
       print(model.stats)
     }
 
-    assert(ex.isInstanceOf[SparkException])
     assert(ex.getMessage.contains(
-      "Values from column input3 must be indices, but got 5.1"))
-
+      "Values MUST be non-negative integers, but got 5.1"))
   }
 
   test("TargetEncoder - invalid label") {
@@ -407,7 +405,6 @@ class TargetEncoderSuite extends MLTest with DefaultReadWriteTest {
     model.stats.zip(expected_stats_continuous).foreach{
       case (actual, expected) => assert(actual.equals(expected))
    }
-
  }
 
   test("TargetEncoder - non-binary labels") {
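
The intercepts change from `SparkException` to `SparkRuntimeException` because validation now happens inside the query plan via `raise_error`, which surfaces at execution time as a `SparkRuntimeException`. A sketch of that behavior as a ScalaTest case, assuming the suite's usual `spark` session scaffolding; the test name and message are illustrative:

```scala
import org.apache.spark.SparkRuntimeException
import org.apache.spark.sql.functions.{lit, raise_error}

test("raise_error surfaces as SparkRuntimeException") {
  val ex = intercept[SparkRuntimeException] {
    // Nothing is thrown while building the plan; the error fires on collect().
    spark.range(1).select(raise_error(lit("boom"))).collect()
  }
  assert(ex.getMessage.contains("boom"))
}
```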
@@ -423,15 +420,13 @@ class TargetEncoderSuite extends MLTest with DefaultReadWriteTest {
     val df_non_binary = spark
       .createDataFrame(sc.parallelize(data_binary :+ data_non_binary), schema)
 
-    val ex = intercept[SparkException] {
+    val ex = intercept[SparkRuntimeException] {
       val model = encoder.fit(df_non_binary)
       print(model.stats)
     }
 
-    assert(ex.isInstanceOf[SparkException])
     assert(ex.getMessage.contains(
-      "Values from column label must be binary (0,1) but got 2.0"))
-
+      "Labels for TARGET_BINARY must be {0, 1}, but got 2.0"))
   }
 
   test("TargetEncoder - features renamed") {
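
One note on the correctness of the rewrite: for TARGET_CONTINUOUS the removed seqOp maintained a running mean, m(i+1) = m(i) + (x(i+1) - m(i)) / (i + 1), while the new code simply delegates to `avg(col("label"))`. Folding that update over a sequence yields exactly the arithmetic mean, as this plain-Scala check (values arbitrary) confirms:

```scala
object IncrementalMeanCheck {
  def main(args: Array[String]): Unit = {
    val xs = Seq(2.0, 4.0, 9.0, 1.0)
    // The per-row update the removed code applied: m + (x - m) / (n + 1).
    val incremental = xs.zipWithIndex.foldLeft(0.0) {
      case (m, (x, i)) => m + (x - m) / (i + 1)
    }
    assert(math.abs(incremental - xs.sum / xs.size) < 1e-12)
    println(incremental) // 4.0
  }
}
```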