Skip to content

Commit

Permalink
Refs. dmlc#7547. Using as default serialization format.
Browse files Browse the repository at this point in the history
  • Loading branch information
dotbg committed Jul 12, 2023
1 parent a1367ea commit a4fe177
Show file tree
Hide file tree
Showing 4 changed files with 186 additions and 183 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -439,21 +439,21 @@ class XGBoostClassifierSuite extends AnyFunSuite with PerTest with TmpFolderPerS
assert(compareTwoFiles(new File(modelPath, "data/XGBoostClassificationModel").getPath,
nativeJsonModelPath))

// test default "deprecated"
val modelUbjPath = new File(tempDir.toFile, "xgbcUbj").getPath
model.write.save(modelUbjPath)
val nativeDeprecatedModelPath = new File(tempDir.toFile, "nativeModel").getPath
model.nativeBooster.saveModel(nativeDeprecatedModelPath)
assert(compareTwoFiles(new File(modelUbjPath, "data/XGBoostClassificationModel").getPath,
nativeDeprecatedModelPath))

// json file should be indifferent with ubj file
val modelJsonPath = new File(tempDir.toFile, "xgbcJson").getPath
model.write.option("format", "json").save(modelJsonPath)
val nativeUbjModelPath = new File(tempDir.toFile, "nativeModel1.ubj").getPath
model.nativeBooster.saveModel(nativeUbjModelPath)
assert(!compareTwoFiles(new File(modelJsonPath, "data/XGBoostClassificationModel").getPath,
nativeUbjModelPath))

// test default "deprecated"
val modelUbjPath = new File(tempDir.toFile, "xgbcUbj").getPath
model.write.option("format", "deprecated").save(modelUbjPath)
val nativeDeprecatedModelPath = new File(tempDir.toFile, "nativeModel").getPath
model.nativeBooster.saveModel(nativeDeprecatedModelPath)
assert(compareTwoFiles(new File(modelUbjPath, "data/XGBoostClassificationModel").getPath,
nativeDeprecatedModelPath))
}

test("native json model file should store feature_name and feature_type") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -333,21 +333,21 @@ class XGBoostRegressorSuite extends AnyFunSuite with PerTest with TmpFolderPerSu
assert(compareTwoFiles(new File(modelPath, "data/XGBoostRegressionModel").getPath,
nativeJsonModelPath))

// test default "deprecated"
val modelUbjPath = new File(tempDir.toFile, "xgbcUbj").getPath
model.write.save(modelUbjPath)
val nativeDeprecatedModelPath = new File(tempDir.toFile, "nativeModel").getPath
model.nativeBooster.saveModel(nativeDeprecatedModelPath)
assert(compareTwoFiles(new File(modelUbjPath, "data/XGBoostRegressionModel").getPath,
nativeDeprecatedModelPath))

// json file should be indifferent with ubj file
val modelJsonPath = new File(tempDir.toFile, "xgbcJson").getPath
model.write.option("format", "json").save(modelJsonPath)
val nativeUbjModelPath = new File(tempDir.toFile, "nativeModel1.ubj").getPath
model.nativeBooster.saveModel(nativeUbjModelPath)
assert(!compareTwoFiles(new File(modelJsonPath, "data/XGBoostRegressionModel").getPath,
nativeUbjModelPath))

// test default "deprecated"
val modelUbjPath = new File(tempDir.toFile, "xgbcUbj").getPath
model.write.option("format", "deprecated").save(modelUbjPath)
val nativeDeprecatedModelPath = new File(tempDir.toFile, "nativeModel").getPath
model.nativeBooster.saveModel(nativeDeprecatedModelPath)
assert(compareTwoFiles(new File(modelUbjPath, "data/XGBoostRegressionModel").getPath,
nativeDeprecatedModelPath))
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@
*/
package ml.dmlc.xgboost4j.java;

import java.io.*;
import java.io.IOException;
import java.io.OutputStream;
import java.io.Serializable;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
Expand All @@ -33,8 +35,8 @@
/**
* Booster for xgboost, this is a model API that support interactive build of a XGBoost Model
*/
public class Booster implements Serializable, KryoSerializable {
public static final String DEFAULT_FORMAT = "deprecated";
public class Booster implements Serializable, KryoSerializable, AutoCloseable {
public static final String DEFAULT_FORMAT = "ubj";
private static final Log logger = LogFactory.getLog(Booster.class);
// handle to the booster.
private long handle = 0;
Expand Down Expand Up @@ -815,6 +817,11 @@ protected void finalize() throws Throwable {
dispose();
}

@Override
public void close() throws XGBoostError {
dispose();
}

public synchronized void dispose() {
if (handle != 0L) {
XGBoostJNI.XGBoosterFree(handle);
Expand Down
Loading

0 comments on commit a4fe177

Please sign in to comment.