Skip to content

Commit

Permalink
redefine checkpoint policy.
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Sep 21, 2023
1 parent 762185e commit 90a3060
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 44 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -50,13 +50,13 @@ class ExternalCheckpointManagerSuite extends AnyFunSuite with TmpFolderPerSuite
manager.updateCheckpoint(model2._booster.booster)
var files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath))
assert(files.length == 1)
assert(files.head.getPath.getName == "2.model")
assert(files.head.getPath.getName == "1.model")
assert(manager.loadCheckpointAsScalaBooster().getNumBoostedRound == 2)

manager.updateCheckpoint(model4._booster)
files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath))
assert(files.length == 1)
assert(files.head.getPath.getName == "4.model")
assert(files.head.getPath.getName == "3.model")
assert(manager.loadCheckpointAsScalaBooster().getNumBoostedRound == 4)
}

Expand All @@ -65,21 +65,20 @@ class ExternalCheckpointManagerSuite extends AnyFunSuite with TmpFolderPerSuite

val manager = new ExternalCheckpointManager(tmpPath, FileSystem.get(sc.hadoopConfiguration))
manager.updateCheckpoint(model4._booster)
manager.cleanUpHigherVersions(4)
assert(new File(s"$tmpPath/4.model").exists())
manager.cleanUpHigherVersions(3)
assert(new File(s"$tmpPath/3.model").exists())

manager.cleanUpHigherVersions(2)
assert(!new File(s"$tmpPath/4.model").exists())
assert(!new File(s"$tmpPath/3.model").exists())
}

test("test checkpoint rounds") {
import scala.collection.JavaConverters._
val (tmpPath, model2, model4) = createNewModels()
val manager = new ExternalCheckpointManager(tmpPath, FileSystem.get(sc.hadoopConfiguration))
assertResult(Seq(3))(manager.getCheckpointRounds(0, 3).asScala)
assertResult(Seq(0, 1, 2, 3))(manager.getCheckpointRounds(0, 3).asScala)
manager.updateCheckpoint(model2._booster)
assertResult(Seq(4, 6, 7))(manager.getCheckpointRounds(2, 7).asScala)
assertResult(Seq(2))(manager.getCheckpointRounds(0, 0, 3).asScala)
assertResult(Seq(0, 2, 4, 6))(manager.getCheckpointRounds(0, 2, 7).asScala)
assertResult(Seq(0, 2, 4, 6, 7))(manager.getCheckpointRounds(0, 2, 8).asScala)
}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ public class ExternalCheckpointManager {

private Log logger = LogFactory.getLog("ExternalCheckpointManager");
private String modelSuffix = ".model";
private Path checkpointPath;
private Path checkpointPath; // directory for checkpoints
private FileSystem fs;

public ExternalCheckpointManager(String checkpointPath, FileSystem fs) throws XGBoostError {
Expand All @@ -50,6 +50,7 @@ private List<Integer> getExistingVersions() throws IOException {
if (!fs.exists(checkpointPath)) {
return new ArrayList<>();
} else {
// Get integer versions from a list of checkpoint files.
return Arrays.stream(fs.listStatus(checkpointPath))
.map(path -> path.getPath().getName())
.filter(fileName -> fileName.endsWith(modelSuffix))
Expand All @@ -59,14 +60,19 @@ private List<Integer> getExistingVersions() throws IOException {
}
}

/**
 * Returns the largest version number contained in {@code versions}.
 *
 * @param versions checkpoint version numbers; expected to hold at least one element
 * @return the maximum value found in the list
 * @throws java.util.NoSuchElementException if {@code versions} is empty
 */
private Integer latest(List<Integer> versions) {
  // Versions are plain integers, so their natural ordering decides which one is newest.
  Comparator<Integer> byVersion = Comparator.naturalOrder();
  return versions.stream().max(byVersion).get();
}

/**
 * Deletes the checkpoint directory, passing {@code true} as the recursive flag.
 *
 * @throws IOException if the file system reports a failure while deleting
 */
public void cleanPath() throws IOException {
  final boolean recursive = true;
  fs.delete(checkpointPath, recursive);
}

public Booster loadCheckpointAsBooster() throws IOException, XGBoostError {
List<Integer> versions = getExistingVersions();
if (versions.size() > 0) {
int latestVersion = versions.stream().max(Comparator.comparing(Integer::valueOf)).get();
int latestVersion = this.latest(versions);
String checkpointPath = getPath(latestVersion);
InputStream in = fs.open(new Path(checkpointPath));
logger.info("loaded checkpoint from " + checkpointPath);
Expand All @@ -79,13 +85,16 @@ public Booster loadCheckpointAsBooster() throws IOException, XGBoostError {

public void updateCheckpoint(Booster boosterToCheckpoint) throws IOException, XGBoostError {
List<String> prevModelPaths = getExistingVersions().stream()
.map(this::getPath).collect(Collectors.toList());
String eventualPath = getPath(boosterToCheckpoint.getNumBoostedRound());
.map(this::getPath).collect(Collectors.toList());
// checkpointing is done after update, so n_rounds - 1 is the current iteration
// accounting for training continuation.
Integer iter = boosterToCheckpoint.getNumBoostedRound() - 1;
String eventualPath = getPath(iter);
String tempPath = eventualPath + "-" + UUID.randomUUID();
try (OutputStream out = fs.create(new Path(tempPath), true)) {
boosterToCheckpoint.saveModel(out);
fs.rename(new Path(tempPath), new Path(eventualPath));
logger.info("saving checkpoint with version " + boosterToCheckpoint.getNumBoostedRound());
logger.info("saving checkpoint with version " + iter);
prevModelPaths.stream().forEach(path -> {
try {
fs.delete(new Path(path), true);
Expand All @@ -97,35 +106,34 @@ public void updateCheckpoint(Booster boosterToCheckpoint) throws IOException, XG
}

public void cleanUpHigherVersions(int currentRound) throws IOException {
getExistingVersions().stream().filter(v -> v / 2 >= currentRound).forEach(v -> {
getExistingVersions().stream().filter(v -> v > currentRound).forEach(v -> {
try {
fs.delete(new Path(getPath(v)), true);
} catch (IOException e) {
logger.error("failed to clean checkpoint from other training instance", e);
}
});
}

public List<Integer> getCheckpointRounds(int checkpointInterval, int numOfRounds)
// Get a list of iterations that need checkpointing.
public List<Integer> getCheckpointRounds(
int firstRound, int checkpointInterval, int numOfRounds)
throws IOException {
Integer end = firstRound + numOfRounds; // exclusive
Integer lastRound = end - 1;
if (end - 1 < 0) {
throw new IllegalArgumentException("Inavlid `numOfRounds`.");
}

List<Integer> arr = new ArrayList<>();
if (checkpointInterval > 0) {
List<Integer> prevRounds =
getExistingVersions().stream().map(v -> v / 2).collect(Collectors.toList());
prevRounds.add(0);
int firstCheckpointRound = prevRounds.stream()
.max(Comparator.comparing(Integer::valueOf)).get() + checkpointInterval;
List<Integer> arr = new ArrayList<>();
for (int i = firstCheckpointRound; i <= numOfRounds; i += checkpointInterval) {
for (int i = firstRound; i < end; i += checkpointInterval) {
arr.add(i);
}
arr.add(numOfRounds);
return arr;
} else if (checkpointInterval <= 0) {
List<Integer> l = new ArrayList<Integer>();
l.add(numOfRounds);
return l;
} else {
throw new IllegalArgumentException("parameters \"checkpoint_path\" should also be set.");
}

if (!arr.contains(lastRound)) {
arr.add(lastRound);
}
return arr;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ public static Booster train(
int earlyStoppingRound) throws XGBoostError {
return train(dtrain, params, round, watches, metrics, obj, eval, earlyStoppingRound, null);
}

// save checkpoint if iter is in checkpointIterations
private static void saveCheckpoint(
Booster booster,
int iter,
Expand Down Expand Up @@ -169,7 +169,6 @@ public static Booster trainAndSaveCheckpoint(
int bestIteration;
List<String> names = new ArrayList<String>();
List<DMatrix> mats = new ArrayList<DMatrix>();
Set<Integer> checkpointIterations = new HashSet<>();
ExternalCheckpointManager ecm = null;
if (checkpointPath != null) {
ecm = new ExternalCheckpointManager(checkpointPath, fs);
Expand Down Expand Up @@ -208,8 +207,10 @@ public static Booster trainAndSaveCheckpoint(
booster.setParams(params);
}

Set<Integer> checkpointIterations = new HashSet<>();
if (ecm != null) {
checkpointIterations = new HashSet<>(ecm.getCheckpointRounds(checkpointInterval, numRounds));
checkpointIterations = new HashSet<>(
ecm.getCheckpointRounds(booster.getNumBoostedRound(), checkpointInterval, numRounds));
}

boolean initial_best_score_flag = false;
Expand Down
27 changes: 18 additions & 9 deletions python-package/xgboost/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -540,7 +540,10 @@ def after_training(self, model: _Model) -> _Model:


class TrainingCheckPoint(TrainingCallback):
"""Checkpointing operation.
"""Checkpointing operation. Users are encouraged to create their own callbacks for
checkpoint as XGBoost doesn't handle distributed file systems. When checkpointing on
distributed systems, be sure to know the rank of the worker to avoid multiple
workers checkpointing to the same place.
.. versionadded:: 1.3.0
Expand All @@ -553,9 +556,9 @@ class TrainingCheckPoint(TrainingCallback):
pattern of output model file. Models will be saved as name_0.json, name_1.json,
name_2.json ....
as_pickle :
When set to True, all training parameters will be saved in pickle format, instead
of saving only the model.
iterations :
When set to True, all training parameters will be saved in pickle format,
instead of saving only the model.
interval :
Interval of checkpointing. Checkpointing is slow so setting a larger number can
reduce performance hit.
Expand All @@ -566,15 +569,20 @@ def __init__(
directory: Union[str, os.PathLike],
name: str = "model",
as_pickle: bool = False,
iterations: int = 100,
interval: int = 100,
) -> None:
self._path = os.fspath(directory)
self._name = name
self._as_pickle = as_pickle
self._iterations = iterations
self._epoch = 0
self._iterations = interval
self._epoch = 0 # counter for interval
self._start = 0 # beginning iteration
super().__init__()

def before_training(self, model: _Model) -> _Model:
    """Remember how many rounds ``model`` has been boosted before training starts.

    The count is stored in ``self._start``; the model itself is returned unchanged.
    """
    already_boosted = model.num_boosted_rounds()
    self._start = already_boosted
    return model

def after_iteration(
self, model: _Model, epoch: int, evals_log: TrainingCallback.EvalsLog
) -> bool:
Expand All @@ -583,11 +591,12 @@ def after_iteration(
self._path,
self._name
+ "_"
+ str(epoch)
+ (str(epoch + self._start))
+ (".pkl" if self._as_pickle else ".json"),
)
self._epoch = 0
self._epoch = 0 # reset counter
if collective.get_rank() == 0:
# checkpoint using the first worker
if self._as_pickle:
with open(path, "wb") as fd:
pickle.dump(model, fd)
Expand Down

0 comments on commit 90a3060

Please sign in to comment.