Skip to content

Commit

Permalink
[breaking] [jvm-packages] Remove rabit check point.
Browse files Browse the repository at this point in the history
- Add `numBoostedRound` to jvm packages
- Remove rabit checkpoint version. [breaking]
- Change the starting version of training continuation in JVM [breaking].

The last item is a bit more subtle: the change aligns the JVM packages with the Python
package. After this PR, the second training phase counts iterations from 0 instead of
from the previous starting iteration.
  • Loading branch information
trivialfis committed Sep 20, 2023
1 parent 38ac52d commit bccf531
Show file tree
Hide file tree
Showing 11 changed files with 59 additions and 147 deletions.
18 changes: 0 additions & 18 deletions include/xgboost/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -1308,24 +1308,6 @@ XGB_DLL int XGBoosterSerializeToBuffer(BoosterHandle handle, bst_ulong *out_len,
XGB_DLL int XGBoosterUnserializeFromBuffer(BoosterHandle handle,
const void *buf, bst_ulong len);

/*!
* \brief Initialize the booster from rabit checkpoint.
* This is used in distributed training API.
* \param handle handle
* \param version The output version of the model.
* \return 0 when success, -1 when failure happens
*/
XGB_DLL int XGBoosterLoadRabitCheckpoint(BoosterHandle handle,
int* version);

/*!
* \brief Save the current checkpoint to rabit.
* \param handle handle
* \return 0 when success, -1 when failure happens
*/
XGB_DLL int XGBoosterSaveRabitCheckpoint(BoosterHandle handle);


/*!
* \brief Save XGBoost's internal configuration into a JSON document. Currently the
* support is experimental, function signature may change in the future without
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
Copyright (c) 2014-2022 by Contributors
Copyright (c) 2014-2023 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -51,13 +51,13 @@ class ExternalCheckpointManagerSuite extends AnyFunSuite with TmpFolderPerSuite
var files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath))
assert(files.length == 1)
assert(files.head.getPath.getName == "4.model")
assert(manager.loadCheckpointAsScalaBooster().getVersion == 4)
assert(manager.loadCheckpointAsScalaBooster().getNumBoostedRound == 4)

manager.updateCheckpoint(model8._booster)
files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath))
assert(files.length == 1)
assert(files.head.getPath.getName == "8.model")
assert(manager.loadCheckpointAsScalaBooster().getVersion == 8)
assert(manager.loadCheckpointAsScalaBooster().getNumBoostedRound == 8)
}

test("test cleanUpHigherVersions") {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
Copyright (c) 2014-2022 by Contributors
Copyright (c) 2014-2023 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -808,14 +808,6 @@ private String[] getDumpInfo(boolean withStats) throws XGBoostError {
return modelInfos[0];
}

public int getVersion() {
return this.version;
}

public void setVersion(int version) {
this.version = version;
}

/**
* Save model into raw byte array. Currently it's using the deprecated format as
* default, which will be changed into `ubj` in future releases.
Expand All @@ -841,29 +833,6 @@ public byte[] toByteArray(String format) throws XGBoostError {
return bytes[0];
}

/**
* Load the booster model from thread-local rabit checkpoint.
* This is only used in distributed training.
* @return the stored version number of the checkpoint.
* @throws XGBoostError
*/
int loadRabitCheckpoint() throws XGBoostError {
int[] out = new int[1];
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterLoadRabitCheckpoint(this.handle, out));
version = out[0];
return version;
}

/**
* Save the booster model into thread-local rabit checkpoint and increment the version.
* This is only used in distributed training.
* @throws XGBoostError
*/
void saveRabitCheckpoint() throws XGBoostError {
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterSaveRabitCheckpoint(this.handle));
version += 1;
}

/**
* Get number of model features.
* @return the number of features.
Expand All @@ -874,6 +843,11 @@ public long getNumFeature() throws XGBoostError {
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterGetNumFeature(this.handle, numFeature));
return numFeature[0];
}
/**
 * Get the number of boosted rounds of this booster, as reported by the native library.
 * @return the number of boosted rounds.
 * @throws XGBoostError if the native call reports a failure
 */
public int getNumBoostedRound() throws XGBoostError {
// The native call writes its result into element 0 of this one-slot output array.
final int[] rounds = new int[1];
final int status = XGBoostJNI.XGBoosterGetNumBoostedRound(this.handle, rounds);
XGBoostJNI.checkCall(status);
return rounds[0];
}

/**
* Internal initialization function.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,18 @@
/*
Copyright (c) 2014-2023 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.java;

import java.io.IOException;
Expand Down Expand Up @@ -56,7 +71,6 @@ public Booster loadCheckpointAsBooster() throws IOException, XGBoostError {
InputStream in = fs.open(new Path(checkpointPath));
logger.info("loaded checkpoint from " + checkpointPath);
Booster booster = XGBoost.loadModel(in);
booster.setVersion(latestVersion);
return booster;
} else {
return null;
Expand All @@ -66,12 +80,12 @@ public Booster loadCheckpointAsBooster() throws IOException, XGBoostError {
public void updateCheckpoint(Booster boosterToCheckpoint) throws IOException, XGBoostError {
List<String> prevModelPaths = getExistingVersions().stream()
.map(this::getPath).collect(Collectors.toList());
String eventualPath = getPath(boosterToCheckpoint.getVersion());
String eventualPath = getPath(boosterToCheckpoint.getNumBoostedRound());
String tempPath = eventualPath + "-" + UUID.randomUUID();
try (OutputStream out = fs.create(new Path(tempPath), true)) {
boosterToCheckpoint.saveModel(out);
fs.rename(new Path(tempPath), new Path(eventualPath));
logger.info("saving checkpoint with version " + boosterToCheckpoint.getVersion());
logger.info("saving checkpoint with version " + boosterToCheckpoint.getNumBoostedRound());
prevModelPaths.stream().forEach(path -> {
try {
fs.delete(new Path(path), true);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
Copyright (c) 2014,2021 by Contributors
Copyright (c) 2014-2023 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -203,7 +203,6 @@ public static Booster trainAndSaveCheckpoint(
booster = new Booster(params, allMats);
booster.setFeatureNames(dtrain.getFeatureNames());
booster.setFeatureTypes(dtrain.getFeatureTypes());
booster.loadRabitCheckpoint();
} else {
// Start training on an existing booster
booster.setParams(params);
Expand All @@ -217,18 +216,11 @@ public static Booster trainAndSaveCheckpoint(
boolean max_direction = false;

// begin to train
for (int iter = booster.getVersion() / 2; iter < numRounds; iter++) {
if (booster.getVersion() % 2 == 0) {
if (obj != null) {
booster.update(dtrain, obj);
} else {
booster.update(dtrain, iter);
}
saveCheckpoint(booster, iter, checkpointIterations, ecm);
booster.saveRabitCheckpoint();
}
for (int iter = 0; iter < numRounds; iter++) {
booster.update(dtrain, iter);
saveCheckpoint(booster, iter, checkpointIterations, ecm);

//evaluation
// evaluation
if (evalMats.length > 0) {
float[] metricsOut = new float[evalMats.length];
String evalInfo;
Expand Down Expand Up @@ -285,7 +277,6 @@ public static Booster trainAndSaveCheckpoint(
Communicator.communicatorPrint(evalInfo + '\n');
}
}
booster.saveRabitCheckpoint();
}
return booster;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -140,10 +140,11 @@ public final static native int XGBoosterDumpModelExWithFeatures(
public final static native int XGBoosterGetAttrNames(long handle, String[][] out_strings);
public final static native int XGBoosterGetAttr(long handle, String key, String[] out_string);
public final static native int XGBoosterSetAttr(long handle, String key, String value);
public final static native int XGBoosterLoadRabitCheckpoint(long handle, int[] out_version);
public final static native int XGBoosterSaveRabitCheckpoint(long handle);

public final static native int XGBoosterGetNumFeature(long handle, long[] feature);

public final static native int XGBoosterGetNumBoostedRound(long handle, int[] rounds);

// communicator functions
public final static native int CommunicatorInit(String[] args);
public final static native int CommunicatorFinalize();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,7 @@ class Booster private[xgboost4j](private[xgboost4j] var booster: JBooster)
@throws(classOf[XGBoostError])
def getNumFeature: Long = booster.getNumFeature

def getVersion: Int = booster.getVersion
def getNumBoostedRound: Long = booster.getNumBoostedRound

/**
* Save model into a raw byte array. Available options are "json", "ubj" and "deprecated".
Expand Down
38 changes: 11 additions & 27 deletions jvm-packages/xgboost4j/src/native/xgboost4j.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -984,33 +984,6 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterSetAttr
return ret;
}

/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGBoosterLoadRabitCheckpoint
* Signature: (J[I)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterLoadRabitCheckpoint
(JNIEnv *jenv , jclass jcls, jlong jhandle, jintArray jout) {
BoosterHandle handle = (BoosterHandle) jhandle;
int version;
int ret = XGBoosterLoadRabitCheckpoint(handle, &version);
JVM_CHECK_CALL(ret);
jint jversion = version;
jenv->SetIntArrayRegion(jout, 0, 1, &jversion);
return ret;
}

/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGBoosterSaveRabitCheckpoint
* Signature: (J)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterSaveRabitCheckpoint
(JNIEnv *jenv, jclass jcls, jlong jhandle) {
BoosterHandle handle = (BoosterHandle) jhandle;
return XGBoosterSaveRabitCheckpoint(handle);
}

/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGBoosterGetNumFeature
Expand All @@ -1027,6 +1000,17 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetNumFea
return ret;
}

/*
 * Class:     ml_dmlc_xgboost4j_java_XGBoostJNI
 * Method:    XGBoosterGetNumBoostedRound
 * Signature: (J[I)I
 *
 * Queries the booster behind `jhandle` through XGBoosterBoostedRounds and stores the
 * resulting round count in element 0 of `jout`. Returns the status code of that call.
 */
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetNumBoostedRound(
JNIEnv *jenv, jclass, jlong jhandle, jintArray jout) {
// The Java side passes the native booster handle as a jlong.
BoosterHandle handle = (BoosterHandle)jhandle;
std::int32_t n_rounds{0};
auto ret = XGBoosterBoostedRounds(handle, &n_rounds);
// NOTE(review): JVM_CHECK_CALL is defined outside this view; presumably it reports a
// failed call to the JVM and returns early - confirm against the macro definition.
JVM_CHECK_CALL(ret);
// Copy through a jint local so SetIntArrayRegion receives a pointer of the JNI int type.
jint jn_rounds = n_rounds;
// Write exactly one element, at index 0 of the caller-supplied output array.
jenv->SetIntArrayRegion(jout, 0, 1, &jn_rounds);
return ret;
}

/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: CommunicatorInit
Expand Down
24 changes: 8 additions & 16 deletions jvm-packages/xgboost4j/src/native/xgboost4j.h

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
Copyright (c) 2014-2022 by Contributors
Copyright (c) 2014-2023 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Expand All @@ -16,7 +16,6 @@
package ml.dmlc.xgboost4j.java;

import junit.framework.TestCase;
import org.junit.Assert;
import org.junit.Test;

import java.io.ByteArrayInputStream;
Expand All @@ -31,7 +30,7 @@

/**
* test cases for Booster Inplace Predict
*
*
* @author hzx and Sovrn
*/
public class BoosterImplTest {
Expand Down Expand Up @@ -845,14 +844,12 @@ public void testTrainFromExistingModel() throws XGBoostError, IOException {
float tempBoosterError = eval.eval(tempBooster.predict(testMat, true, 0), testMat);

// Save tempBooster to bytestream and load back
int prevVersion = tempBooster.getVersion();
ByteArrayInputStream in = new ByteArrayInputStream(tempBooster.toByteArray());
tempBooster = XGBoost.loadModel(in);
in.close();
tempBooster.setVersion(prevVersion);

// Continue training using tempBooster
round = 4;
round = 2;
Booster booster2 = XGBoost.train(trainMat, paramMap, round, watches, null, null, null, 0, tempBooster);
float booster2error = eval.eval(booster2.predict(testMat, true, 0), testMat);
TestCase.assertTrue(booster1error == booster2error);
Expand Down
27 changes: 2 additions & 25 deletions src/c_api/c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1430,36 +1430,13 @@ XGB_DLL int XGBoosterUnserializeFromBuffer(BoosterHandle handle,
API_END();
}

XGB_DLL int XGBoosterLoadRabitCheckpoint(BoosterHandle handle,
int* version) {
API_BEGIN();
CHECK_HANDLE();
auto* bst = static_cast<Learner*>(handle);
xgboost_CHECK_C_ARG_PTR(version);
*version = rabit::LoadCheckPoint();
if (*version != 0) {
bst->Configure();
}
API_END();
}

XGB_DLL int XGBoosterSaveRabitCheckpoint(BoosterHandle handle) {
API_BEGIN();
CHECK_HANDLE();
auto *learner = static_cast<Learner *>(handle);
learner->Configure();
rabit::CheckPoint();
API_END();
}

XGB_DLL int XGBoosterSlice(BoosterHandle handle, int begin_layer,
int end_layer, int step,
XGB_DLL int XGBoosterSlice(BoosterHandle handle, int begin_layer, int end_layer, int step,
BoosterHandle *out) {
API_BEGIN();
CHECK_HANDLE();
xgboost_CHECK_C_ARG_PTR(out);

auto* learner = static_cast<Learner*>(handle);
auto *learner = static_cast<Learner *>(handle);
bool out_of_bound = false;
auto p_out = learner->Slice(begin_layer, end_layer, step, &out_of_bound);
if (out_of_bound) {
Expand Down

0 comments on commit bccf531

Please sign in to comment.