feat: restore Ray Tuner to complete unfinished HPO run
Also handle OOM errors during HPO trials.
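
Below is a minimal, hedged sketch of the OOM-skip pattern this change introduces (assuming Ray >= 2.x and a CUDA-enabled PyTorch build; `run_one_trial` and the `"val_loss"` metric name are illustrative stand-ins for this repo's `train_ray_trial` loop and its validation metric):

```python
# Minimal sketch, not the repo's exact code. Assumes Ray >= 2.x and a CUDA-enabled
# PyTorch build; run_one_trial and "val_loss" are illustrative placeholders.
import logging

import numpy as np
import ray.train
import torch


def run_one_trial(config: dict) -> None:
    """Placeholder for the per-worker training loop of a single HPO trial."""
    ...


def guarded_trial(config: dict) -> None:
    """Train-loop body that skips hyperparameter configs which exhaust GPU memory."""
    try:
        run_one_trial(config)
    except torch.cuda.OutOfMemoryError:
        # Report a NaN metric so the Tuner records the trial as finished instead of
        # failing the whole experiment, then release cached GPU memory so the next
        # configuration starts from a clean state.
        ray.train.report({"val_loss": np.nan})
        torch.cuda.empty_cache()
        if ray.train.get_context().get_local_rank() == 0:
            logging.warning("OOM error encountered, skipping this hyperparameter configuration.")
```

In the actual diff, rank 0 additionally appends the skipped search-space values to a `skipped_configurations.txt` file under the Ray Tune local directory, so skipped configurations can be audited later.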
erwulff committed Feb 8, 2024
1 parent 08f0572 commit 6a517b7
Showing 1 changed file with 56 additions and 22 deletions.
78 changes: 56 additions & 22 deletions mlpf/pyg/training.py
@@ -939,10 +939,30 @@ def run_ray_training(config, args, outdir):


def set_searchspace_and_run_trial(search_space, config, args):
import ray
from raytune.pt_search_space import set_hps_from_search_space

rank = ray.train.get_context().get_local_rank()

config = set_hps_from_search_space(search_space, config)
train_ray_trial(config, args, outdir=None) # outdir will be taken from the TrainContext in each trial
try:
# outdir will be taken from the ray.train.context.TrainContext in each trial
train_ray_trial(config, args, outdir=None)
except torch.cuda.OutOfMemoryError:
ray.train.report({"val_loss": np.NAN})
torch.cuda.empty_cache() # make sure GPU memory is cleared for next trial
if rank == 0:
logging.warning("OOM error encountered, skipping this hyperparameter configuration.")
skiplog_file_path = Path(config["raytune"]["local_dir"]) / args.hpo / "skipped_configurations.txt"
lines = ["{}: {}\n".format(item[0], item[1]) for item in search_space.items()]

with open(skiplog_file_path, "a") as f:
f.write("#" * 80 + "\n")
for line in lines:
f.write(line)
logging.warning(line[:-1])
f.write("#" * 80 + "\n\n")
logging.warning("Done writing warnings to log.")


def run_hpo(config, args):
@@ -978,11 +998,13 @@ def run_hpo(config, args):
yaml.dump(config, file)

if not args.local:
_logger.info("Inititalizing ray...")
ray.init(
address=os.environ["ip_head"],
_node_ip_address=os.environ["head_node_ip"],
_temp_dir="/mnt/ceph/users/ewulff/tmp_ray",
# _temp_dir="/p/project/raise-ctp2/cern/tmp_ray",
)
_logger.info("Done.")

sched = get_raytune_schedule(config["raytune"])
search_alg = get_raytune_search_alg(config["raytune"])
@@ -992,30 +1014,42 @@ def run_hpo(config, args):
use_gpu=True,
resources_per_worker={"CPU": args.ray_cpus // (args.gpus) - 1, "GPU": 1}, # -1 to avoid blocking
)

if tune.Tuner.can_restore(str(expdir)):
args.resume_training = True

trainable = tune.with_parameters(set_searchspace_and_run_trial, config=config, args=args)
trainer = TorchTrainer(train_loop_per_worker=trainable, scaling_config=scaling_config)

search_space = {"train_loop_config": search_space} # the ray TorchTrainer only takes a single arg: train_loop_config
tuner = tune.Tuner(
trainer,
param_space=search_space,
tune_config=tune.TuneConfig(
num_samples=raytune_num_samples,
metric=config["raytune"]["default_metric"] if (search_alg is None and sched is None) else None,
mode=config["raytune"]["default_mode"] if (search_alg is None and sched is None) else None,
search_alg=search_alg,
scheduler=sched,
),
run_config=ray.train.RunConfig(
name=name,
storage_path=config["raytune"]["local_dir"],
log_to_file=False,
failure_config=ray.train.FailureConfig(max_failures=2),
checkpoint_config=ray.train.CheckpointConfig(num_to_keep=1), # keep only latest checkpoint
sync_config=ray.train.SyncConfig(sync_artifacts=True),
),
)
if tune.Tuner.can_restore(str(expdir)):
# resume unfinished HPO run
tuner = tune.Tuner.restore(
str(expdir), trainable=trainer, resume_errored=True, restart_errored=False, resume_unfinished=True
)
else:
# start new HPO run
search_space = {"train_loop_config": search_space} # the ray TorchTrainer only takes a single arg: train_loop_config
tuner = tune.Tuner(
trainer,
param_space=search_space,
tune_config=tune.TuneConfig(
num_samples=raytune_num_samples,
metric=config["raytune"]["default_metric"] if (search_alg is None and sched is None) else None,
mode=config["raytune"]["default_mode"] if (search_alg is None and sched is None) else None,
search_alg=search_alg,
scheduler=sched,
),
run_config=ray.train.RunConfig(
name=name,
storage_path=config["raytune"]["local_dir"],
log_to_file=False,
failure_config=ray.train.FailureConfig(max_failures=2),
checkpoint_config=ray.train.CheckpointConfig(num_to_keep=1), # keep only latest checkpoint
sync_config=ray.train.SyncConfig(sync_artifacts=True),
),
)
start = datetime.now()
_logger.info("Starting tuner.fit()")
result_grid = tuner.fit()
end = datetime.now()


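For reference, a condensed, hedged sketch of the resume-or-start branch added to `run_hpo` (assuming Ray >= 2.x; `experiment_dir`, `trainer`, and `search_space` are placeholder arguments for the experiment directory, the `TorchTrainer`, and the hyperparameter search space):

```python
# Condensed sketch of the resume-or-start logic (assumes Ray >= 2.x). The
# arguments are placeholders for the experiment directory, the TorchTrainer,
# and the hyperparameter search space used above.
from ray import tune


def fit_or_resume(experiment_dir: str, trainer, search_space: dict):
    if tune.Tuner.can_restore(experiment_dir):
        # An unfinished run exists: pick it up, retry errored trials, and do not
        # restart trials that already completed.
        tuner = tune.Tuner.restore(
            experiment_dir,
            trainable=trainer,
            resume_errored=True,
            restart_errored=False,
            resume_unfinished=True,
        )
    else:
        # No previous run: start fresh. The search space is wrapped in
        # "train_loop_config" because a TorchTrainer's train loop takes a single
        # config argument.
        tuner = tune.Tuner(trainer, param_space={"train_loop_config": search_space})
    return tuner.fit()
```

When a restorable experiment is found, the diff also sets `args.resume_training = True`, presumably so each trial's training loop reloads model and optimizer state from its latest checkpoint rather than starting from scratch.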