From bccf531e1ff082bffcf94619b384b42f290e1e9b Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Wed, 20 Sep 2023 22:50:59 +0800 Subject: [PATCH] [breaking] [jvm-packages] Remove rabit check point. - Add `numBoostedRound` to jvm packages - Remove rabit checkpoint version. [breaking] - Change the starting version of training continuation in JVM [breaking]. The last item is a bit more subtle, the change aligns JVM packages with Python packages. After this PR, the second training phrase counts iteration from 0 instead of from the previous starting iteration. --- include/xgboost/c_api.h | 18 --------- .../ExternalCheckpointManagerSuite.scala | 6 +-- .../java/ml/dmlc/xgboost4j/java/Booster.java | 38 +++---------------- .../java/ExternalCheckpointManager.java | 20 ++++++++-- .../java/ml/dmlc/xgboost4j/java/XGBoost.java | 19 +++------- .../ml/dmlc/xgboost4j/java/XGBoostJNI.java | 5 ++- .../ml/dmlc/xgboost4j/scala/Booster.scala | 2 +- .../xgboost4j/src/native/xgboost4j.cpp | 38 ++++++------------- jvm-packages/xgboost4j/src/native/xgboost4j.h | 24 ++++-------- .../dmlc/xgboost4j/java/BoosterImplTest.java | 9 ++--- src/c_api/c_api.cc | 27 +------------ 11 files changed, 59 insertions(+), 147 deletions(-) diff --git a/include/xgboost/c_api.h b/include/xgboost/c_api.h index 9bce616efb84..5df62df55017 100644 --- a/include/xgboost/c_api.h +++ b/include/xgboost/c_api.h @@ -1308,24 +1308,6 @@ XGB_DLL int XGBoosterSerializeToBuffer(BoosterHandle handle, bst_ulong *out_len, XGB_DLL int XGBoosterUnserializeFromBuffer(BoosterHandle handle, const void *buf, bst_ulong len); -/*! - * \brief Initialize the booster from rabit checkpoint. - * This is used in distributed training API. - * \param handle handle - * \param version The output version of the model. - * \return 0 when success, -1 when failure happens - */ -XGB_DLL int XGBoosterLoadRabitCheckpoint(BoosterHandle handle, - int* version); - -/*! - * \brief Save the current checkpoint to rabit. - * \param handle handle - * \return 0 when success, -1 when failure happens - */ -XGB_DLL int XGBoosterSaveRabitCheckpoint(BoosterHandle handle); - - /*! * \brief Save XGBoost's internal configuration into a JSON document. Currently the * support is experimental, function signature may change in the future without diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/ExternalCheckpointManagerSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/ExternalCheckpointManagerSuite.scala index adc9c10687be..c126e95ed86e 100755 --- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/ExternalCheckpointManagerSuite.scala +++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/ExternalCheckpointManagerSuite.scala @@ -1,5 +1,5 @@ /* - Copyright (c) 2014-2022 by Contributors + Copyright (c) 2014-2023 by Contributors Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -51,13 +51,13 @@ class ExternalCheckpointManagerSuite extends AnyFunSuite with TmpFolderPerSuite var files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath)) assert(files.length == 1) assert(files.head.getPath.getName == "4.model") - assert(manager.loadCheckpointAsScalaBooster().getVersion == 4) + assert(manager.loadCheckpointAsScalaBooster().getNumBoostedRound == 4) manager.updateCheckpoint(model8._booster) files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath)) assert(files.length == 1) assert(files.head.getPath.getName == "8.model") - assert(manager.loadCheckpointAsScalaBooster().getVersion == 8) + assert(manager.loadCheckpointAsScalaBooster().getNumBoostedRound == 8) } test("test cleanUpHigherVersions") { diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Booster.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Booster.java index 7ed12c704a9f..4fdce62eefbe 100644 --- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Booster.java +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Booster.java @@ -1,5 +1,5 @@ /* - Copyright (c) 2014-2022 by Contributors + Copyright (c) 2014-2023 by Contributors Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -808,14 +808,6 @@ private String[] getDumpInfo(boolean withStats) throws XGBoostError { return modelInfos[0]; } - public int getVersion() { - return this.version; - } - - public void setVersion(int version) { - this.version = version; - } - /** * Save model into raw byte array. Currently it's using the deprecated format as * default, which will be changed into `ubj` in future releases. @@ -841,29 +833,6 @@ public byte[] toByteArray(String format) throws XGBoostError { return bytes[0]; } - /** - * Load the booster model from thread-local rabit checkpoint. - * This is only used in distributed training. - * @return the stored version number of the checkpoint. - * @throws XGBoostError - */ - int loadRabitCheckpoint() throws XGBoostError { - int[] out = new int[1]; - XGBoostJNI.checkCall(XGBoostJNI.XGBoosterLoadRabitCheckpoint(this.handle, out)); - version = out[0]; - return version; - } - - /** - * Save the booster model into thread-local rabit checkpoint and increment the version. - * This is only used in distributed training. - * @throws XGBoostError - */ - void saveRabitCheckpoint() throws XGBoostError { - XGBoostJNI.checkCall(XGBoostJNI.XGBoosterSaveRabitCheckpoint(this.handle)); - version += 1; - } - /** * Get number of model features. * @return the number of features. @@ -874,6 +843,11 @@ public long getNumFeature() throws XGBoostError { XGBoostJNI.checkCall(XGBoostJNI.XGBoosterGetNumFeature(this.handle, numFeature)); return numFeature[0]; } + public int getNumBoostedRound() throws XGBoostError { + int[] numRound = new int[1]; + XGBoostJNI.checkCall(XGBoostJNI.XGBoosterGetNumBoostedRound(this.handle, numRound)); + return numRound[0]; + } /** * Internal initialization function. diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/ExternalCheckpointManager.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/ExternalCheckpointManager.java index 655b99020313..c0cc2d8911d6 100644 --- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/ExternalCheckpointManager.java +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/ExternalCheckpointManager.java @@ -1,3 +1,18 @@ +/* + Copyright (c) 2014-2023 by Contributors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + */ package ml.dmlc.xgboost4j.java; import java.io.IOException; @@ -56,7 +71,6 @@ public Booster loadCheckpointAsBooster() throws IOException, XGBoostError { InputStream in = fs.open(new Path(checkpointPath)); logger.info("loaded checkpoint from " + checkpointPath); Booster booster = XGBoost.loadModel(in); - booster.setVersion(latestVersion); return booster; } else { return null; @@ -66,12 +80,12 @@ public Booster loadCheckpointAsBooster() throws IOException, XGBoostError { public void updateCheckpoint(Booster boosterToCheckpoint) throws IOException, XGBoostError { List prevModelPaths = getExistingVersions().stream() .map(this::getPath).collect(Collectors.toList()); - String eventualPath = getPath(boosterToCheckpoint.getVersion()); + String eventualPath = getPath(boosterToCheckpoint.getNumBoostedRound()); String tempPath = eventualPath + "-" + UUID.randomUUID(); try (OutputStream out = fs.create(new Path(tempPath), true)) { boosterToCheckpoint.saveModel(out); fs.rename(new Path(tempPath), new Path(eventualPath)); - logger.info("saving checkpoint with version " + boosterToCheckpoint.getVersion()); + logger.info("saving checkpoint with version " + boosterToCheckpoint.getNumBoostedRound()); prevModelPaths.stream().forEach(path -> { try { fs.delete(new Path(path), true); diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoost.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoost.java index bcd0b1b11d2f..d42e8e963aa0 100644 --- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoost.java +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoost.java @@ -1,5 +1,5 @@ /* - Copyright (c) 2014,2021 by Contributors + Copyright (c) 2014-2023 by Contributors Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -203,7 +203,6 @@ public static Booster trainAndSaveCheckpoint( booster = new Booster(params, allMats); booster.setFeatureNames(dtrain.getFeatureNames()); booster.setFeatureTypes(dtrain.getFeatureTypes()); - booster.loadRabitCheckpoint(); } else { // Start training on an existing booster booster.setParams(params); @@ -217,18 +216,11 @@ public static Booster trainAndSaveCheckpoint( boolean max_direction = false; // begin to train - for (int iter = booster.getVersion() / 2; iter < numRounds; iter++) { - if (booster.getVersion() % 2 == 0) { - if (obj != null) { - booster.update(dtrain, obj); - } else { - booster.update(dtrain, iter); - } - saveCheckpoint(booster, iter, checkpointIterations, ecm); - booster.saveRabitCheckpoint(); - } + for (int iter = 0; iter < numRounds; iter++) { + booster.update(dtrain, iter); + saveCheckpoint(booster, iter, checkpointIterations, ecm); - //evaluation + // evaluation if (evalMats.length > 0) { float[] metricsOut = new float[evalMats.length]; String evalInfo; @@ -285,7 +277,6 @@ public static Booster trainAndSaveCheckpoint( Communicator.communicatorPrint(evalInfo + '\n'); } } - booster.saveRabitCheckpoint(); } return booster; } diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoostJNI.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoostJNI.java index eabbf29ba945..236d53e900a9 100644 --- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoostJNI.java +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoostJNI.java @@ -140,10 +140,11 @@ public final static native int XGBoosterDumpModelExWithFeatures( public final static native int XGBoosterGetAttrNames(long handle, String[][] out_strings); public final static native int XGBoosterGetAttr(long handle, String key, String[] out_string); public final static native int XGBoosterSetAttr(long handle, String key, String value); - public final static native int XGBoosterLoadRabitCheckpoint(long handle, int[] out_version); - public final static native int XGBoosterSaveRabitCheckpoint(long handle); + public final static native int XGBoosterGetNumFeature(long handle, long[] feature); + public final static native int XGBoosterGetNumBoostedRound(long handle, int[] rounds); + // communicator functions public final static native int CommunicatorInit(String[] args); public final static native int CommunicatorFinalize(); diff --git a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/Booster.scala b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/Booster.scala index 31be86898e5a..c288bfab19fb 100644 --- a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/Booster.scala +++ b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/Booster.scala @@ -326,7 +326,7 @@ class Booster private[xgboost4j](private[xgboost4j] var booster: JBooster) @throws(classOf[XGBoostError]) def getNumFeature: Long = booster.getNumFeature - def getVersion: Int = booster.getVersion + def getNumBoostedRound: Long = booster.getNumBoostedRound /** * Save model into a raw byte array. Available options are "json", "ubj" and "deprecated". diff --git a/jvm-packages/xgboost4j/src/native/xgboost4j.cpp b/jvm-packages/xgboost4j/src/native/xgboost4j.cpp index 821b1ebff054..332b1a12774b 100644 --- a/jvm-packages/xgboost4j/src/native/xgboost4j.cpp +++ b/jvm-packages/xgboost4j/src/native/xgboost4j.cpp @@ -984,33 +984,6 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterSetAttr return ret; } -/* - * Class: ml_dmlc_xgboost4j_java_XGBoostJNI - * Method: XGBoosterLoadRabitCheckpoint - * Signature: (J[I)I - */ -JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterLoadRabitCheckpoint - (JNIEnv *jenv , jclass jcls, jlong jhandle, jintArray jout) { - BoosterHandle handle = (BoosterHandle) jhandle; - int version; - int ret = XGBoosterLoadRabitCheckpoint(handle, &version); - JVM_CHECK_CALL(ret); - jint jversion = version; - jenv->SetIntArrayRegion(jout, 0, 1, &jversion); - return ret; -} - -/* - * Class: ml_dmlc_xgboost4j_java_XGBoostJNI - * Method: XGBoosterSaveRabitCheckpoint - * Signature: (J)I - */ -JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterSaveRabitCheckpoint - (JNIEnv *jenv, jclass jcls, jlong jhandle) { - BoosterHandle handle = (BoosterHandle) jhandle; - return XGBoosterSaveRabitCheckpoint(handle); -} - /* * Class: ml_dmlc_xgboost4j_java_XGBoostJNI * Method: XGBoosterGetNumFeature @@ -1027,6 +1000,17 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetNumFea return ret; } +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetNumBoostedRound( + JNIEnv *jenv, jclass, jlong jhandle, jintArray jout) { + BoosterHandle handle = (BoosterHandle)jhandle; + std::int32_t n_rounds{0}; + auto ret = XGBoosterBoostedRounds(handle, &n_rounds); + JVM_CHECK_CALL(ret); + jint jn_rounds = n_rounds; + jenv->SetIntArrayRegion(jout, 0, 1, &jn_rounds); + return ret; +} + /* * Class: ml_dmlc_xgboost4j_java_XGBoostJNI * Method: CommunicatorInit diff --git a/jvm-packages/xgboost4j/src/native/xgboost4j.h b/jvm-packages/xgboost4j/src/native/xgboost4j.h index 87ff6d30db6a..cc4ad53d4e4c 100644 --- a/jvm-packages/xgboost4j/src/native/xgboost4j.h +++ b/jvm-packages/xgboost4j/src/native/xgboost4j.h @@ -287,22 +287,6 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetAttr JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterSetAttr (JNIEnv *, jclass, jlong, jstring, jstring); -/* - * Class: ml_dmlc_xgboost4j_java_XGBoostJNI - * Method: XGBoosterLoadRabitCheckpoint - * Signature: (J[I)I - */ -JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterLoadRabitCheckpoint - (JNIEnv *, jclass, jlong, jintArray); - -/* - * Class: ml_dmlc_xgboost4j_java_XGBoostJNI - * Method: XGBoosterSaveRabitCheckpoint - * Signature: (J)I - */ -JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterSaveRabitCheckpoint - (JNIEnv *, jclass, jlong); - /* * Class: ml_dmlc_xgboost4j_java_XGBoostJNI * Method: XGBoosterGetNumFeature @@ -311,6 +295,14 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterSaveRabit JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetNumFeature (JNIEnv *, jclass, jlong, jlongArray); +/* + * Class: ml_dmlc_xgboost4j_java_XGBoostJNI + * Method: XGBoosterGetNumBoostedRound + * Signature: (J[I)I + */ +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetNumBoostedRound + (JNIEnv *, jclass, jlong, jintArray); + /* * Class: ml_dmlc_xgboost4j_java_XGBoostJNI * Method: CommunicatorInit diff --git a/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/BoosterImplTest.java b/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/BoosterImplTest.java index c7508b20d8ea..b686ddbed858 100644 --- a/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/BoosterImplTest.java +++ b/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/BoosterImplTest.java @@ -1,5 +1,5 @@ /* - Copyright (c) 2014-2022 by Contributors + Copyright (c) 2014-2023 by Contributors Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,7 +16,6 @@ package ml.dmlc.xgboost4j.java; import junit.framework.TestCase; -import org.junit.Assert; import org.junit.Test; import java.io.ByteArrayInputStream; @@ -31,7 +30,7 @@ /** * test cases for Booster Inplace Predict - * + * * @author hzx and Sovrn */ public class BoosterImplTest { @@ -845,14 +844,12 @@ public void testTrainFromExistingModel() throws XGBoostError, IOException { float tempBoosterError = eval.eval(tempBooster.predict(testMat, true, 0), testMat); // Save tempBooster to bytestream and load back - int prevVersion = tempBooster.getVersion(); ByteArrayInputStream in = new ByteArrayInputStream(tempBooster.toByteArray()); tempBooster = XGBoost.loadModel(in); in.close(); - tempBooster.setVersion(prevVersion); // Continue training using tempBooster - round = 4; + round = 2; Booster booster2 = XGBoost.train(trainMat, paramMap, round, watches, null, null, null, 0, tempBooster); float booster2error = eval.eval(booster2.predict(testMat, true, 0), testMat); TestCase.assertTrue(booster1error == booster2error); diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index 2b0862d4945a..f6ab8d4dfe32 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -1430,36 +1430,13 @@ XGB_DLL int XGBoosterUnserializeFromBuffer(BoosterHandle handle, API_END(); } -XGB_DLL int XGBoosterLoadRabitCheckpoint(BoosterHandle handle, - int* version) { - API_BEGIN(); - CHECK_HANDLE(); - auto* bst = static_cast(handle); - xgboost_CHECK_C_ARG_PTR(version); - *version = rabit::LoadCheckPoint(); - if (*version != 0) { - bst->Configure(); - } - API_END(); -} - -XGB_DLL int XGBoosterSaveRabitCheckpoint(BoosterHandle handle) { - API_BEGIN(); - CHECK_HANDLE(); - auto *learner = static_cast(handle); - learner->Configure(); - rabit::CheckPoint(); - API_END(); -} - -XGB_DLL int XGBoosterSlice(BoosterHandle handle, int begin_layer, - int end_layer, int step, +XGB_DLL int XGBoosterSlice(BoosterHandle handle, int begin_layer, int end_layer, int step, BoosterHandle *out) { API_BEGIN(); CHECK_HANDLE(); xgboost_CHECK_C_ARG_PTR(out); - auto* learner = static_cast(handle); + auto *learner = static_cast(handle); bool out_of_bound = false; auto p_out = learner->Slice(begin_layer, end_layer, step, &out_of_bound); if (out_of_bound) {