diff --git a/jvm-packages/xgboost4j-spark-gpu/src/main/scala/ml/dmlc/xgboost4j/scala/spark/GpuXGBoostPlugin.scala b/jvm-packages/xgboost4j-spark-gpu/src/main/scala/ml/dmlc/xgboost4j/scala/spark/GpuXGBoostPlugin.scala index 275263a34ef5..6ab9f679d706 100644 --- a/jvm-packages/xgboost4j-spark-gpu/src/main/scala/ml/dmlc/xgboost4j/scala/spark/GpuXGBoostPlugin.scala +++ b/jvm-packages/xgboost4j-spark-gpu/src/main/scala/ml/dmlc/xgboost4j/scala/spark/GpuXGBoostPlugin.scala @@ -93,7 +93,8 @@ class GpuXGBoostPlugin extends XGBoostPlugin { selectedCols.append(col) } val input = dataset.select(selectedCols.toArray: _*) - estimator.repartitionIfNeeded(input) + val repartitioned = estimator.repartitionIfNeeded(input) + estimator.sortPartitionIfNeeded(repartitioned) } // visible for testing diff --git a/jvm-packages/xgboost4j-spark-gpu/src/test/scala/ml/dmlc/xgboost4j/scala/spark/GpuXGBoostPluginSuite.scala b/jvm-packages/xgboost4j-spark-gpu/src/test/scala/ml/dmlc/xgboost4j/scala/spark/GpuXGBoostPluginSuite.scala index 97f54b601eb3..c84a8b51a146 100644 --- a/jvm-packages/xgboost4j-spark-gpu/src/test/scala/ml/dmlc/xgboost4j/scala/spark/GpuXGBoostPluginSuite.scala +++ b/jvm-packages/xgboost4j-spark-gpu/src/test/scala/ml/dmlc/xgboost4j/scala/spark/GpuXGBoostPluginSuite.scala @@ -16,14 +16,14 @@ package ml.dmlc.xgboost4j.scala.spark -import ai.rapids.cudf.Table +import ai.rapids.cudf.{OrderByArg, Table} import ml.dmlc.xgboost4j.java.CudfColumnBatch import ml.dmlc.xgboost4j.scala.{DMatrix, QuantileDMatrix, XGBoost => ScalaXGBoost} import ml.dmlc.xgboost4j.scala.rapids.spark.GpuTestSuite import ml.dmlc.xgboost4j.scala.rapids.spark.SparkSessionHolder.withSparkSession import ml.dmlc.xgboost4j.scala.spark.Utils.withResource import org.apache.spark.ml.linalg.DenseVector -import org.apache.spark.sql.{Dataset, SparkSession} +import org.apache.spark.sql.{Dataset, Row, SparkSession} import org.apache.spark.SparkConf import java.io.File @@ -94,7 +94,9 @@ class GpuXGBoostPluginSuite extends GpuTestSuite { } // spark.rapids.sql.enabled is not set explicitly, default to true - withSparkSession(new SparkConf(), spark => {checkIsEnabled(spark, true)}) + withSparkSession(new SparkConf(), spark => { + checkIsEnabled(spark, true) + }) // set spark.rapids.sql.enabled to false withCpuSparkSession() { spark => @@ -503,6 +505,109 @@ class GpuXGBoostPluginSuite extends GpuTestSuite { } } + test("The group col should be sorted in each partition") { + withGpuSparkSession() { spark => + import spark.implicits._ + val df = Ranking.train.toDF("label", "weight", "group", "c1", "c2", "c3") + + val xgboostParams: Map[String, Any] = Map( + "device" -> "cuda", + "objective" -> "rank:ndcg" + ) + val features = Array("c1", "c2", "c3") + val label = "label" + val group = "group" + + val ranker = new XGBoostRanker(xgboostParams) + .setFeaturesCol(features) + .setLabelCol(label) + .setNumWorkers(1) + .setNumRound(1) + .setGroupCol(group) + .setDevice("cuda") + + val processedDf = ranker.getPlugin.get.asInstanceOf[GpuXGBoostPlugin].preprocess(ranker, df) + processedDf.rdd.foreachPartition { iter => { + var prevGroup = Int.MinValue + while (iter.hasNext) { + val curr = iter.next() + val group = curr.asInstanceOf[Row].getAs[Int](1) + assert(prevGroup <= group) + prevGroup = group + } + } + } + } + } + + test("Ranker: XGBoost-Spark should match xgboost4j") { + withGpuSparkSession() { spark => + import spark.implicits._ + + val trainPath = writeFile(Ranking.train.toDF("label", "weight", "group", "c1", "c2", "c3")) + val testPath = writeFile(Ranking.test.toDF("label", "weight", "group", "c1", "c2", "c3")) + + val df = spark.read.parquet(trainPath) + val testdf = spark.read.parquet(testPath) + + val features = Array("c1", "c2", "c3") + val featuresIndices = features.map(df.schema.fieldIndex) + val label = "label" + val group = "group" + + val numRound = 100 + val xgboostParams: Map[String, Any] = Map( + "device" -> "cuda", + "objective" -> "rank:ndcg" + ) + + val ranker = new XGBoostRanker(xgboostParams) + .setFeaturesCol(features) + .setLabelCol(label) + .setNumRound(numRound) + .setLeafPredictionCol("leaf") + .setContribPredictionCol("contrib") + .setGroupCol(group) + .setDevice("cuda") + + val xgb4jModel = withResource(new GpuColumnBatch( + Table.readParquet(new File(trainPath) + ).orderBy(OrderByArg.asc(df.schema.fieldIndex(group))))) { batch => + val cb = new CudfColumnBatch(batch.select(featuresIndices), + batch.select(df.schema.fieldIndex(label)), null, null, + batch.select(df.schema.fieldIndex(group))) + val qdm = new QuantileDMatrix(Seq(cb).iterator, ranker.getMissing, + ranker.getMaxBins, ranker.getNthread) + ScalaXGBoost.train(qdm, xgboostParams, numRound) + } + + val (xgb4jLeaf, xgb4jContrib, xgb4jPred) = withResource(new GpuColumnBatch( + Table.readParquet(new File(testPath)))) { batch => + val cb = new CudfColumnBatch(batch.select(featuresIndices), null, null, null, null + ) + val qdm = new DMatrix(cb, ranker.getMissing, ranker.getNthread) + (xgb4jModel.predictLeaf(qdm), xgb4jModel.predictContrib(qdm), + xgb4jModel.predict(qdm)) + } + + val rows = ranker.fit(df).transform(testdf).collect() + + // Check Leaf + val xgbSparkLeaf = rows.map(row => row.getAs[DenseVector]("leaf").toArray.map(_.toFloat)) + checkEqual(xgb4jLeaf, xgbSparkLeaf) + + // Check contrib + val xgbSparkContrib = rows.map(row => + row.getAs[DenseVector]("contrib").toArray.map(_.toFloat)) + checkEqual(xgb4jContrib, xgbSparkContrib) + + // Check prediction + val xgbSparkPred = rows.map(row => + Array(row.getAs[Double]("prediction").toFloat)) + checkEqual(xgb4jPred, xgbSparkPred) + } + } + def writeFile(df: Dataset[_]): String = { def listFiles(directory: String): Array[String] = { val dir = new File(directory) diff --git a/jvm-packages/xgboost4j-spark-gpu/src/test/scala/ml/dmlc/xgboost4j/scala/spark/TrainTestData.scala b/jvm-packages/xgboost4j-spark-gpu/src/test/scala/ml/dmlc/xgboost4j/scala/spark/TrainTestData.scala index 49c790fd0a00..043385137af6 100644 --- a/jvm-packages/xgboost4j-spark-gpu/src/test/scala/ml/dmlc/xgboost4j/scala/spark/TrainTestData.scala +++ b/jvm-packages/xgboost4j-spark-gpu/src/test/scala/ml/dmlc/xgboost4j/scala/spark/TrainTestData.scala @@ -81,6 +81,6 @@ object Regression extends TrainTestData { } object Ranking extends TrainTestData { - val train = generateRankDataset(300, 10, 555) - val test = generateRankDataset(150, 10, 556) + val train = generateRankDataset(300, 10, 12, 555) + val test = generateRankDataset(150, 10, 12, 556) } diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostEstimator.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostEstimator.scala index aaf2e07a7091..6978b82da8fc 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostEstimator.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostEstimator.scala @@ -134,6 +134,15 @@ private[spark] trait XGBoostEstimator[ } } + /** + * Sort partition for Ranker issue. + * @param dataset + * @return + */ + private[spark] def sortPartitionIfNeeded(dataset: Dataset[_]): Dataset[_] = { + dataset + } + /** * Build the columns indices. */ @@ -198,10 +207,10 @@ private[spark] trait XGBoostEstimator[ case p: HasGroupCol => selectCol(p.groupCol, IntegerType) case _ => } - val input = repartitionIfNeeded(dataset.select(selectedCols.toArray: _*)) - - val columnIndices = buildColumnIndices(input.schema) - (input, columnIndices) + val repartitioned = repartitionIfNeeded(dataset.select(selectedCols.toArray: _*)) + val sorted = sortPartitionIfNeeded(repartitioned) + val columnIndices = buildColumnIndices(sorted.schema) + (sorted, columnIndices) } /** visible for testing */ diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRanker.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRanker.scala new file mode 100644 index 000000000000..6e020560e6f6 --- /dev/null +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRanker.scala @@ -0,0 +1,124 @@ +/* + Copyright (c) 2024 by Contributors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + */ + +package ml.dmlc.xgboost4j.scala.spark + +import org.apache.spark.ml.{PredictionModel, Predictor} +import org.apache.spark.ml.linalg.Vector +import org.apache.spark.ml.param.ParamMap +import org.apache.spark.ml.util.{DefaultParamsReadable, Identifiable, MLReadable, MLReader} +import org.apache.spark.ml.xgboost.SparkUtils +import org.apache.spark.sql.Dataset +import ml.dmlc.xgboost4j.scala.Booster +import ml.dmlc.xgboost4j.scala.spark.XGBoostRanker._uid +import ml.dmlc.xgboost4j.scala.spark.params.HasGroupCol +import ml.dmlc.xgboost4j.scala.spark.params.LearningTaskParams.RANKER_OBJS +import org.apache.spark.sql.types.{DataType, DoubleType, StructType} + +class XGBoostRanker(override val uid: String, + private val xgboostParams: Map[String, Any]) + extends Predictor[Vector, XGBoostRanker, XGBoostRankerModel] + with XGBoostEstimator[XGBoostRanker, XGBoostRankerModel] with HasGroupCol { + + def this() = this(_uid, Map[String, Any]()) + + def this(uid: String) = this(uid, Map[String, Any]()) + + def this(xgboostParams: Map[String, Any]) = this(_uid, xgboostParams) + + def setGroupCol(value: String): XGBoostRanker = set(groupCol, value) + + xgboost2SparkParams(xgboostParams) + + /** + * Validate the parameters before training, throw exception if possible + */ + override protected[spark] def validate(dataset: Dataset[_]): Unit = { + super.validate(dataset) + + require(isDefinedNonEmpty(groupCol), "groupCol needs to be set") + + // If the objective is set explicitly, it must be in RANKER_OBJS + if (isSet(objective)) { + val tmpObj = getObjective + require(RANKER_OBJS.contains(tmpObj), + s"Wrong objective for XGBoostRanker, supported objs: ${RANKER_OBJS.mkString(",")}") + } else { + setObjective("rank:ndcg") + } + } + + /** + * Sort partition for Ranker issue. + * + * @param dataset + * @return + */ + override private[spark] def sortPartitionIfNeeded(dataset: Dataset[_]) = { + dataset.sortWithinPartitions(getGroupCol) + } + + override protected def createModel( + booster: Booster, + summary: XGBoostTrainingSummary): XGBoostRankerModel = { + new XGBoostRankerModel(uid, booster, Option(summary)) + } + + override protected def validateAndTransformSchema( + schema: StructType, + fitting: Boolean, + featuresDataType: DataType): StructType = + SparkUtils.appendColumn(schema, $(predictionCol), DoubleType) +} + +object XGBoostRanker extends DefaultParamsReadable[XGBoostRanker] { + private val _uid = Identifiable.randomUID("xgbranker") +} + +class XGBoostRankerModel private[ml](val uid: String, + val nativeBooster: Booster, + val summary: Option[XGBoostTrainingSummary] = None) + extends PredictionModel[Vector, XGBoostRankerModel] + with RankerRegressorBaseModel[XGBoostRankerModel] with HasGroupCol { + + def this(uid: String) = this(uid, null) + + def setGroupCol(value: String): XGBoostRankerModel = set(groupCol, value) + + override def copy(extra: ParamMap): XGBoostRankerModel = { + val newModel = copyValues(new XGBoostRankerModel(uid, nativeBooster, summary), extra) + newModel.setParent(parent) + } + + override def predict(features: Vector): Double = { + val values = predictSingleInstance(features) + values(0) + } +} + +object XGBoostRankerModel extends MLReadable[XGBoostRankerModel] { + override def read: MLReader[XGBoostRankerModel] = new ModelReader + + private class ModelReader extends XGBoostModelReader[XGBoostRankerModel] { + override def load(path: String): XGBoostRankerModel = { + val xgbModel = loadBooster(path) + val meta = SparkUtils.loadMetadata(path, sc) + val model = new XGBoostRankerModel(meta.uid, xgbModel, None) + meta.getAndSetParams(model) + model + } + } +} diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRankerSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRankerSuite.scala new file mode 100644 index 000000000000..81a770bfe327 --- /dev/null +++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRankerSuite.scala @@ -0,0 +1,309 @@ +/* + Copyright (c) 2024 by Contributors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + */ + +package ml.dmlc.xgboost4j.scala.spark + +import java.io.File + +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.ml.linalg.{DenseVector, Vectors} +import org.apache.spark.ml.param.ParamMap +import org.apache.spark.sql.{DataFrame, Dataset, Row} +import org.scalatest.funsuite.AnyFunSuite + +import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => ScalaXGBoost} +import ml.dmlc.xgboost4j.scala.spark.Regression.Ranking +import ml.dmlc.xgboost4j.scala.spark.params.LearningTaskParams.RANKER_OBJS +import ml.dmlc.xgboost4j.scala.spark.params.XGBoostParams + +class XGBoostRankerSuite extends AnyFunSuite with PerTest with TmpFolderPerSuite { + + test("XGBoostRanker copy") { + val ranker = new XGBoostRanker().setNthread(2).setNumWorkers(10) + val rankertCopied = ranker.copy(ParamMap.empty) + + assert(ranker.uid === rankertCopied.uid) + assert(ranker.getNthread === rankertCopied.getNthread) + assert(ranker.getNumWorkers === ranker.getNumWorkers) + } + + test("XGBoostRankerModel copy") { + val model = new XGBoostRankerModel("hello").setNthread(2).setNumWorkers(10) + val modelCopied = model.copy(ParamMap.empty) + assert(model.uid === modelCopied.uid) + assert(model.getNthread === modelCopied.getNthread) + assert(model.getNumWorkers === modelCopied.getNumWorkers) + } + + test("read/write") { + val trainDf = smallGroupVector + val xgbParams: Map[String, Any] = Map( + "max_depth" -> 5, + "eta" -> 0.2, + "objective" -> "rank:ndcg" + ) + + def check(xgboostParams: XGBoostParams[_]): Unit = { + assert(xgboostParams.getMaxDepth === 5) + assert(xgboostParams.getEta === 0.2) + assert(xgboostParams.getObjective === "rank:ndcg") + } + + val rankerPath = new File(tempDir.toFile, "ranker").getPath + val ranker = new XGBoostRanker(xgbParams).setNumRound(1).setGroupCol("group") + check(ranker) + assert(ranker.getGroupCol === "group") + + ranker.write.overwrite().save(rankerPath) + val loadedRanker = XGBoostRanker.load(rankerPath) + check(loadedRanker) + assert(loadedRanker.getGroupCol === "group") + + val model = loadedRanker.fit(trainDf) + check(model) + assert(model.getGroupCol === "group") + + val modelPath = new File(tempDir.toFile, "model").getPath + model.write.overwrite().save(modelPath) + val modelLoaded = XGBoostRankerModel.load(modelPath) + check(modelLoaded) + assert(modelLoaded.getGroupCol === "group") + } + + test("validate") { + val trainDf = smallGroupVector + val ranker = new XGBoostRanker() + // must define group column + intercept[IllegalArgumentException]( + ranker.validate(trainDf) + ) + val ranker1 = new XGBoostRanker().setGroupCol("group") + ranker1.validate(trainDf) + assert(ranker1.getObjective === "rank:ndcg") + } + + test("XGBoostRankerModel transformed schema") { + val trainDf = smallGroupVector + val ranker = new XGBoostRanker().setGroupCol("group").setNumRound(1) + val model = ranker.fit(trainDf) + var out = model.transform(trainDf) + // Transform should not discard the other columns of the transforming dataframe + Seq("label", "group", "margin", "weight", "features").foreach { v => + assert(out.schema.names.contains(v)) + } + // Ranker does not have extra columns + Seq("rawPrediction", "probability").foreach { v => + assert(!out.schema.names.contains(v)) + } + assert(out.schema.names.contains("prediction")) + assert(out.schema.names.length === 6) + model.setLeafPredictionCol("leaf").setContribPredictionCol("contrib") + out = model.transform(trainDf) + assert(out.schema.names.contains("leaf")) + assert(out.schema.names.contains("contrib")) + } + + test("Supported objectives") { + val ranker = new XGBoostRanker().setGroupCol("group") + val df = smallGroupVector + RANKER_OBJS.foreach { obj => + ranker.setObjective(obj) + ranker.validate(df) + } + + ranker.setObjective("binary:logistic") + intercept[IllegalArgumentException]( + ranker.validate(df) + ) + } + + test("The group col should be sorted in each partition") { + val trainingDF = buildDataFrameWithGroup(Ranking.train) + + val ranker = new XGBoostRanker() + .setNumRound(1) + .setNumWorkers(numWorkers) + .setGroupCol("group") + + val (df, _) = ranker.preprocess(trainingDF) + df.rdd.foreachPartition { iter => { + var prevGroup = Int.MinValue + while (iter.hasNext) { + val curr = iter.next() + val group = curr.asInstanceOf[Row].getAs[Int](2) + assert(prevGroup <= group) + prevGroup = group + } + }} + } + + private def runLengthEncode(input: Seq[Int]): Seq[Int] = { + if (input.isEmpty) return Seq(0) + + input.indices + .filter(i => i == 0 || input(i) != input(i - 1)) :+ input.length + } + + private def runRanker(ranker: XGBoostRanker, dataset: Dataset[_]): (Array[Float], Array[Int]) = { + val (df, indices) = ranker.preprocess(dataset) + val rdd = ranker.toRdd(df, indices) + val result = rdd.mapPartitions { iter => + if (iter.hasNext) { + val watches = iter.next() + val dm = watches.toMap(Utils.TRAIN_NAME) + val weight = dm.getWeight + val group = dm.getGroup + watches.delete() + Iterator.single((weight, group)) + } else { + Iterator.empty + } + }.collect() + + val weight: ArrayBuffer[Float] = ArrayBuffer.empty + val group: ArrayBuffer[Int] = ArrayBuffer.empty + + for (row <- result) { + weight.append(row._1: _*) + group.append(row._2: _*) + } + (weight.toArray, group.toArray) + } + + Seq(None, Some("weight")).foreach { weightCol => { + val msg = weightCol.map(_ => "with weight").getOrElse("without weight") + test(s"to RDD watches with group $msg") { + // One instance without setting weight + var df = ss.createDataFrame(sc.parallelize(Seq( + (1.0, 0, 10, Vectors.dense(Array(1.0, 2.0, 3.0))) + ))).toDF("label", "group", "weight", "features") + + val ranker = new XGBoostRanker() + .setLabelCol("label") + .setFeaturesCol("features") + .setGroupCol("group") + .setNumWorkers(1) + + weightCol.foreach(ranker.setWeightCol) + + val (weights, groupSize) = runRanker(ranker, df) + val expectedWeight = weightCol.map(_ => Array(10.0f)).getOrElse(Array(1.0f)) + assert(weights === expectedWeight) + assert(groupSize === runLengthEncode(Seq(0))) + + df = ss.createDataFrame(sc.parallelize(Seq( + (1.0, 1, 2, Vectors.dense(Array(1.0, 2.0, 3.0))), + (2.0, 1, 2, Vectors.dense(Array(1.0, 2.0, 3.0))), + (1.0, 0, 5, Vectors.dense(Array(1.0, 2.0, 3.0))), + (0.0, 1, 2, Vectors.dense(Array(1.0, 2.0, 3.0))), + (1.0, 0, 5, Vectors.dense(Array(1.0, 2.0, 3.0))), + (2.0, 2, 7, Vectors.dense(Array(1.0, 2.0, 3.0))) + ))).toDF("label", "group", "weight", "features") + + val groups = Array(1, 1, 0, 1, 0, 2).sorted + val (weights1, groupSize1) = runRanker(ranker, df) + val expectedWeight1 = weightCol.map(_ => Array(5.0f, 2.0f, 7.0f)) + .getOrElse(groups.distinct.map(_ => 1.0f)) + + assert(groupSize1 === runLengthEncode(groups)) + assert(weights1 === expectedWeight1) + } + } + } + + test("XGBoost-Spark output should match XGBoost4j") { + val trainingDM = new DMatrix(Ranking.train.iterator) + val weights = Ranking.trainGroups.distinct.map(_ => 1.0f).toArray + trainingDM.setQueryId(Ranking.trainGroups.toArray) + trainingDM.setWeight(weights) + + val testDM = new DMatrix(Ranking.test.iterator) + val trainingDF = buildDataFrameWithGroup(Ranking.train) + val testDF = buildDataFrameWithGroup(Ranking.test) + val paramMap = Map("objective" -> "rank:ndcg") + checkResultsWithXGBoost4j(trainingDM, testDM, trainingDF, testDF, 5, paramMap) + } + + test("XGBoost-Spark output with weight should match XGBoost4j") { + val trainingDM = new DMatrix(Ranking.trainWithWeight.iterator) + trainingDM.setQueryId(Ranking.trainGroups.toArray) + trainingDM.setWeight(Ranking.trainGroups.distinct.map(_.toFloat).toArray) + + val testDM = new DMatrix(Ranking.test.iterator) + val trainingDF = buildDataFrameWithGroup(Ranking.trainWithWeight) + val testDF = buildDataFrameWithGroup(Ranking.test) + val paramMap = Map("objective" -> "rank:ndcg") + checkResultsWithXGBoost4j(trainingDM, testDM, trainingDF, testDF, + 5, paramMap, Some("weight")) + } + + private def checkResultsWithXGBoost4j( + trainingDM: DMatrix, + testDM: DMatrix, + trainingDF: DataFrame, + testDF: DataFrame, + round: Int = 5, + xgbParams: Map[String, Any] = Map.empty, + weightCol: Option[String] = None): Unit = { + val paramMap = Map( + "eta" -> "1", + "max_depth" -> "6", + "base_score" -> 0.5, + "max_bin" -> 16) ++ xgbParams + val xgb4jModel = ScalaXGBoost.train(trainingDM, paramMap, round) + + val ranker = new XGBoostRanker(paramMap) + .setNumRound(round) + // If we use multi workers to train the ranking, the result probably will be different + .setNumWorkers(1) + .setLeafPredictionCol("leaf") + .setContribPredictionCol("contrib") + .setGroupCol("group") + weightCol.foreach(weight => ranker.setWeightCol(weight)) + + def checkEqual(left: Array[Array[Float]], right: Map[Int, Array[Float]]) = { + assert(left.size === right.size) + left.zipWithIndex.foreach { case (leftValue, index) => + assert(leftValue.sameElements(right(index))) + } + } + + val xgbSparkModel = ranker.fit(trainingDF) + val rows = xgbSparkModel.transform(testDF).collect() + + // Check Leaf + val xgb4jLeaf = xgb4jModel.predictLeaf(testDM) + val xgbSparkLeaf = rows.map(row => + (row.getAs[Int]("id"), row.getAs[DenseVector]("leaf").toArray.map(_.toFloat))).toMap + checkEqual(xgb4jLeaf, xgbSparkLeaf) + + // Check contrib + val xgb4jContrib = xgb4jModel.predictContrib(testDM) + val xgbSparkContrib = rows.map(row => + (row.getAs[Int]("id"), row.getAs[DenseVector]("contrib").toArray.map(_.toFloat))).toMap + checkEqual(xgb4jContrib, xgbSparkContrib) + + // Check prediction + val xgb4jPred = xgb4jModel.predict(testDM) + val xgbSparkPred = rows.map(row => { + val pred = row.getAs[Double]("prediction").toFloat + (row.getAs[Int]("id"), Array(pred)) + }).toMap + checkEqual(xgb4jPred, xgbSparkPred) + } + +}