From 370f47f939741ca4cc87da00d7f62a396ebdf57f Mon Sep 17 00:00:00 2001 From: Eric Wulff <31319227+erwulff@users.noreply.github.com> Date: Wed, 25 Oct 2023 13:58:45 +0200 Subject: [PATCH] Implement HPO for PyTorch pipeline. (#246) * wip: implement HPO in pytorch pipeline * fix: bugs after rebase * chore: code formatting * fix: minor bug * fix: typo * fix: lr casted to str when read from config * try reducing --ntrain --ntest in tests * update distbarrier and fix stale epochs (#249) * change pytorch CI/CD test to use gravnet model * feat: implemented HPO using Ray Tune Now able to perform hyperparameter search using random search with automatic trial launching and Ray-compatible checkpointing. Support is still missing for: - Trial schedulers - Advanced Ray Tune search algorithms * fix: flake8 error * chore: update default config values for pyg --------- Co-authored-by: Farouk Mokhtar --- mlpf/pyg/mlpf.py | 2 +- mlpf/pyg/training.py | 58 ++++- mlpf/pyg_pipeline.py | 225 +++++++++++++++----- mlpf/raytune/pt_search_space.py | 48 +++++ mlpf/raytune/utils.py | 12 +- parameters/pyg-clic.yaml | 38 +++- parameters/pyg-cms-small.yaml | 18 +- parameters/pyg-cms-test-qcdhighpt.yaml | 18 +- parameters/pyg-cms.yaml | 38 +++- parameters/pyg-delphes.yaml | 18 +- parameters/pyg-workflow-test.yaml | 18 +- scripts/flatiron/pt_raytune_1GPUperTrial.sh | 88 ++++++++ scripts/flatiron/pt_train_4GPUs.slurm | 46 ++++ scripts/local_test_pyg.sh | 2 +- 14 files changed, 542 insertions(+), 87 deletions(-) create mode 100644 mlpf/raytune/pt_search_space.py create mode 100755 scripts/flatiron/pt_raytune_1GPUperTrial.sh create mode 100644 scripts/flatiron/pt_train_4GPUs.slurm diff --git a/mlpf/pyg/mlpf.py b/mlpf/pyg/mlpf.py index c17557262..a5a813f78 100644 --- a/mlpf/pyg/mlpf.py +++ b/mlpf/pyg/mlpf.py @@ -98,7 +98,7 @@ def __init__( for i in range(num_convs): self.conv_id.append(SelfAttentionLayer(embedding_dim)) self.conv_reg.append(SelfAttentionLayer(embedding_dim)) - elif self.conv_type == "gnn-lsh": + elif self.conv_type == "gnn_lsh": self.conv_id = nn.ModuleList() self.conv_reg = nn.ModuleList() for i in range(num_convs): diff --git a/mlpf/pyg/training.py b/mlpf/pyg/training.py index 3f924b1a0..543686a40 100644 --- a/mlpf/pyg/training.py +++ b/mlpf/pyg/training.py @@ -1,6 +1,8 @@ import pickle as pkl +from tempfile import TemporaryDirectory import time from typing import Optional +from pathlib import Path import matplotlib.pyplot as plt import numpy as np @@ -137,7 +139,7 @@ def train( best_val_loss, stale_epochs, patience, - outpath, + outdir, tensorboard_writer=None, ): """ @@ -238,10 +240,10 @@ def train( torch.save( {"model_state_dict": model_state_dict, "optimizer_state_dict": optimizer.state_dict()}, - f"{outpath}/best_weights.pth", + f"{outdir}/best_weights.pth", ) _logger.info( - f"finished {itrain+1}/{len(train_loader)} iterations and saved the model at {outpath}/best_weights.pth" # noqa + f"finished {itrain+1}/{len(train_loader)} iterations and saved the model at {outdir}/best_weights.pth" # noqa ) stale_epochs = torch.tensor(0, device=rank) else: @@ -278,7 +280,7 @@ def train( return epoch_loss, valid_loss, best_val_loss, stale_epochs -def train_mlpf(rank, world_size, model, optimizer, train_loader, valid_loader, num_epochs, patience, outpath): +def train_mlpf(rank, world_size, model, optimizer, train_loader, valid_loader, num_epochs, patience, outdir, hpo=False): """ Will run a full training by calling train().
@@ -288,11 +290,11 @@ def train_mlpf(rank, world_size, model, optimizer, train_loader, valid_loader, n train_loader: a pytorch geometric Dataloader that loads the training data in the form ~ DataBatch(X, ygen, ycands) valid_loader: a pytorch geometric Dataloader that loads the validation data in the form ~ DataBatch(X, ygen, ycands) patience: number of stale epochs before stopping the training - outpath: path to store the model weights and training plots + outdir: path to store the model weights and training plots """ if (rank == 0) or (rank == "cpu"): - tensorboard_writer = SummaryWriter(f"{outpath}/runs/") + tensorboard_writer = SummaryWriter(f"{outdir}/runs/") else: tensorboard_writer = False @@ -306,7 +308,23 @@ def train_mlpf(rank, world_size, model, optimizer, train_loader, valid_loader, n losses["train"][loss], losses["valid"][loss] = [], [] stale_epochs, best_val_loss = torch.tensor(0, device=rank), 99999.9 - for epoch in range(num_epochs): + start_epoch = 0 + + if hpo: + import ray.train as ray_train + from ray.train import Checkpoint + + checkpoint = ray_train.get_checkpoint() + if checkpoint: + with checkpoint.as_directory() as checkpoint_dir: + checkpoint_dir = Path(checkpoint_dir) + # TODO: EW, check if map_location should be "cpu" below + model.load_state_dict(torch.load(checkpoint_dir / "model.pt")) + optimizer.load_state_dict(torch.load(checkpoint_dir / "optim.pt")) + start_epoch = torch.load(checkpoint_dir / "extra_state.pt")["epoch"] + 1 + + for epoch in range(start_epoch, num_epochs): _logger.info(f"Initiating epoch # {epoch}", color="bold") t0 = time.time() @@ -321,10 +339,30 @@ def train_mlpf(rank, world_size, model, optimizer, train_loader, valid_loader, n best_val_loss, stale_epochs, patience, - outpath, + outdir, tensorboard_writer, ) + if hpo: + # save model, optimizer and epoch number for HPO-supported checkpointing + if (rank == 0) or (rank == "cpu"): + # Ray automatically syncs the checkpoint to persistent storage + with TemporaryDirectory() as temp_checkpoint_dir: + temp_checkpoint_dir = Path(temp_checkpoint_dir) + torch.save(model.state_dict(), temp_checkpoint_dir / "model.pt") + torch.save(optimizer.state_dict(), temp_checkpoint_dir / "optim.pt") + torch.save({"epoch": epoch}, temp_checkpoint_dir / "extra_state.pt") + + # report metrics and checkpoint to Ray + ray_train.report( + dict( + loss=losses_t["Total"], + val_loss=losses_v["Total"], + epoch=epoch, + ), + checkpoint=Checkpoint.from_directory(temp_checkpoint_dir), + ) + if stale_epochs > patience: break @@ -378,10 +416,10 @@ def train_mlpf(rank, world_size, model, optimizer, train_loader, valid_loader, n ax.set_ylim(0.8 * losses["train"][loss][-1], 1.2 * losses["train"][loss][-1]) ax.legend(title="MLPF", loc="best", title_fontsize=20, fontsize=15) plt.tight_layout() - plt.savefig(f"{outpath}/mlpf_loss_{loss}.pdf") + plt.savefig(f"{outdir}/mlpf_loss_{loss}.pdf") plt.close() - with open(f"{outpath}/mlpf_losses.pkl", "wb") as f: + with open(f"{outdir}/mlpf_losses.pkl", "wb") as f: pkl.dump(losses, f) _logger.info(f"Done with training.
Total training time on device {rank} is {round((time.time() - t0_initial)/60,3)}min") diff --git a/mlpf/pyg_pipeline.py b/mlpf/pyg_pipeline.py index a50b72b41..6732728a0 100644 --- a/mlpf/pyg_pipeline.py +++ b/mlpf/pyg_pipeline.py @@ -5,11 +5,14 @@ """ import argparse +from datetime import datetime +from functools import partial import logging import os import os.path as osp import pickle as pkl from pathlib import Path +import shutil import yaml import fastjet import torch @@ -23,35 +26,44 @@ from pyg.utils import CLASS_LABELS, X_FEATURES, save_HPs from utils import create_experiment_dir + logging.basicConfig(level=logging.INFO) parser = argparse.ArgumentParser() -parser.add_argument("--config", type=str, default="parameters/pyg-cms.yaml", help="yaml config") -parser.add_argument("--prefix", type=str, default="test_", help="prefix appended to result dir name") -parser.add_argument("--data_dir", type=str, default="/pfvol/tensorflow_datasets/", help="path to `tensorflow_datasets/`") -parser.add_argument("--gpus", type=str, default="0", help="to use CPU set to empty string; else e.g., `0,1`") +# add default=None to all argparse arguments to ensure they do not override +# values loaded from the config file given by --config unless explicitly given +parser.add_argument("--config", type=str, default=None, help="yaml config") +parser.add_argument("--prefix", type=str, default=None, help="prefix appended to result dir name") +parser.add_argument("--data-dir", type=str, default=None, help="path to `tensorflow_datasets/`") +parser.add_argument("--gpus", type=str, default=None, help="to use CPU set to empty string; else e.g., `0,1`") parser.add_argument( - "--gpu-batch-multiplier", type=int, default=1, help="increase batch size per GPU by this constant factor" + "--gpu-batch-multiplier", type=int, default=None, help="Increase batch size per GPU by this constant factor" ) -parser.add_argument("--dataset", type=str, choices=["clic", "cms", "delphes"], required=True, help="which dataset?") -parser.add_argument("--ntrain", type=int, default=None, help="training samples to use, if None use entire dataset") -parser.add_argument("--ntest", type=int, default=None, help="testing samples to use, if None use entire dataset") -parser.add_argument("--nvalid", type=int, default=500, help="validation samples to use, default will use 500 events") -parser.add_argument("--num-workers", type=int, default=0, help="number of processes to load the data") +parser.add_argument( + "--dataset", type=str, default=None, choices=["clic", "cms", "delphes"], required=False, help="which dataset?"
+) +parser.add_argument("--num-workers", type=int, default=None, help="number of processes to load the data") parser.add_argument("--prefetch-factor", type=int, default=None, help="number of samples to fetch & prefetch at every call") parser.add_argument("--load", type=str, default=None, help="dir from which to load a saved model") -parser.add_argument("--train", action="store_true", help="initiates a training") -parser.add_argument("--test", action="store_true", help="tests the model") -parser.add_argument("--num-epochs", type=int, default=3, help="number of training epochs") -parser.add_argument("--patience", type=int, default=20, help="patience before early stopping") -parser.add_argument("--lr", type=float, default=1e-4, help="learning rate") -parser.add_argument("--conv-type", type=str, default="gravnet", help="choices are ['gnn-lsh', 'gravnet', 'attention']") -parser.add_argument("--make-plots", action="store_true", help="make plots of the test predictions") -parser.add_argument("--export-onnx", action="store_true", help="exports the model to onnx") +parser.add_argument("--train", action="store_true", default=None, help="initiates a training") +parser.add_argument("--test", action="store_true", default=None, help="tests the model") +parser.add_argument("--num-epochs", type=int, default=None, help="number of training epochs") +parser.add_argument("--patience", type=int, default=None, help="patience before early stopping") +parser.add_argument("--lr", type=float, default=None, help="learning rate") +parser.add_argument("--conv-type", type=str, default=None, help="choices are ['gnn_lsh', 'gravnet', 'attention']") +parser.add_argument("--make-plots", action="store_true", default=None, help="make plots of the test predictions") +parser.add_argument("--export-onnx", action="store_true", default=None, help="exports the model to onnx") +parser.add_argument("--ntrain", type=int, default=None, help="training samples to use, if None use entire dataset") +parser.add_argument("--ntest", type=int, default=None, help="testing samples to use, if None use entire dataset") +parser.add_argument("--nvalid", type=int, default=500, help="validation samples to use, default will use 500 events") +parser.add_argument("--hpo", type=str, default=None, help="perform hyperparameter optimization, name of HPO experiment") +parser.add_argument("--local", action="store_true", default=None, help="perform HPO locally, without a Ray cluster") +parser.add_argument("--ray-cpus", type=int, default=None, help="CPUs per trial for HPO") +parser.add_argument("--ray-gpus", type=int, default=None, help="GPUs per trial for HPO") -def run(rank, world_size, args, outdir, logfile): +def run(rank, world_size, config, args, outdir, logfile): """Demo function that will be passed to each gpu if (world_size > 1) else will run normally on the given device.""" if world_size > 1: @@ -62,11 +74,9 @@ def run(rank, world_size, args, outdir, logfile): if (rank == 0) or (rank == "cpu"): # keep writing the logs _configLogger("mlpf", filename=logfile) - with open(args.config, "r") as stream: # load config (includes: which physics samples, model params) - config = yaml.safe_load(stream) + if config["load"]: # load a pre-trained model + outdir = config["load"] # in case both --load and --train are provided - if args.load: # load a pre-trained model - outdir = args.load # in case both --load and --train are provided with open(f"{outdir}/model_kwargs.pkl", "rb") as f: model_kwargs = pkl.load(f) @@ -86,12 +96,12 @@ def run(rank, world_size, args,
outdir, logfile): else: # instantiate a new model in the outdir created model_kwargs = { - "input_dim": len(X_FEATURES[args.dataset]), - "num_classes": len(CLASS_LABELS[args.dataset]), - **config["model"][args.conv_type], + "input_dim": len(X_FEATURES[config["dataset"]]), + "num_classes": len(CLASS_LABELS[config["dataset"]]), + **config["model"][config["conv_type"]], } model = MLPF(**model_kwargs) - optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr) + optimizer = torch.optim.AdamW(model.parameters(), lr=config["lr"]) model.to(rank) @@ -105,31 +115,35 @@ def run(rank, world_size, args, outdir, logfile): if args.train: if (rank == 0) or (rank == "cpu"): save_HPs(args, model, model_kwargs, outdir) # save model_kwargs and hyperparameters - _logger.info(f"Creating experiment dir {outdir}") + _logger.info("Creating experiment dir {}".format(outdir)) _logger.info(f"Model directory {outdir}", color="bold") train_loaders = [] - for sample in config["train_dataset"][args.dataset]: - version = config["train_dataset"][args.dataset][sample]["version"] - batch_size = config["train_dataset"][args.dataset][sample]["batch_size"] * args.gpu_batch_multiplier + for sample in config["train_dataset"][config["dataset"]]: + version = config["train_dataset"][config["dataset"]][sample]["version"] + batch_size = config["train_dataset"][config["dataset"]][sample]["batch_size"] * config["gpu_batch_multiplier"] - ds = PFDataset(args.data_dir, f"{sample}:{version}", "train", ["X", "ygen"], num_samples=args.ntrain) + ds = PFDataset(config["data_dir"], f"{sample}:{version}", "train", ["X", "ygen"], num_samples=config["ntrain"]) _logger.info(f"train_dataset: {ds}, {len(ds)}", color="blue") - train_loaders.append(ds.get_loader(batch_size, world_size, args.num_workers, args.prefetch_factor)) + train_loaders.append(ds.get_loader(batch_size, world_size, config["num_workers"], config["prefetch_factor"])) train_loader = InterleavedIterator(train_loaders) if (rank == 0) or (rank == "cpu"): # quick validation only on a single machine valid_loaders = [] - for sample in config["valid_dataset"][args.dataset]: - version = config["valid_dataset"][args.dataset][sample]["version"] - batch_size = config["valid_dataset"][args.dataset][sample]["batch_size"] * args.gpu_batch_multiplier + for sample in config["valid_dataset"][config["dataset"]]: + version = config["valid_dataset"][config["dataset"]][sample]["version"] + batch_size = ( + config["valid_dataset"][config["dataset"]][sample]["batch_size"] * config["gpu_batch_multiplier"] + ) - ds = PFDataset(args.data_dir, f"{sample}:{version}", "test", ["X", "ygen", "ycand"], num_samples=args.nvalid) + ds = PFDataset( + config["data_dir"], f"{sample}:{version}", "test", ["X", "ygen", "ycand"], num_samples=config["nvalid"] + ) _logger.info(f"valid_dataset: {ds}, {len(ds)}", color="blue") - valid_loaders.append(ds.get_loader(batch_size, 1, args.num_workers, args.prefetch_factor)) + valid_loaders.append(ds.get_loader(batch_size, 1, config["num_workers"], config["prefetch_factor"])) valid_loader = InterleavedIterator(valid_loaders) else: @@ -142,29 +156,33 @@ def run(rank, world_size, args, outdir, logfile): optimizer, train_loader, valid_loader, - args.num_epochs, - args.patience, + config["num_epochs"], + config["patience"], outdir, + hpo=True if args.hpo is not None else False, ) if args.test: - if args.load is None: + + if config["load"] is None: # if we don't load, we must have a newly trained model assert args.train, "Please train a model before testing, or load a model with --load" 
assert outdir is not None, "Error: no outdir to evaluate model from" else: - outdir = args.load + outdir = config["load"] test_loaders = {} - for sample in config["test_dataset"][args.dataset]: - version = config["test_dataset"][args.dataset][sample]["version"] - batch_size = config["test_dataset"][args.dataset][sample]["batch_size"] * args.gpu_batch_multiplier + for sample in config["test_dataset"][config["dataset"]]: + version = config["test_dataset"][config["dataset"]][sample]["version"] + batch_size = config["test_dataset"][config["dataset"]][sample]["batch_size"] * config["gpu_batch_multiplier"] - ds = PFDataset(args.data_dir, f"{sample}:{version}", "test", ["X", "ygen", "ycand"], args.ntest) + ds = PFDataset( + config["data_dir"], f"{sample}:{version}", "test", ["X", "ygen", "ycand"], num_samples=config["ntest"] + ) _logger.info(f"test_dataset: {ds}, {len(ds)}", color="blue") test_loaders[sample] = InterleavedIterator( - [ds.get_loader(batch_size, world_size, args.num_workers, args.prefetch_factor)] + [ds.get_loader(batch_size, world_size, config["num_workers"], config["prefetch_factor"])] ) if not osp.isdir(f"{outdir}/preds/{sample}"): @@ -192,10 +210,10 @@ def run(rank, world_size, args, outdir, logfile): if (rank == 0) or (rank == "cpu"): # make plots and export to onnx only on a single machine if args.make_plots: - for sample in config["test_dataset"][args.dataset]: + for sample in config["test_dataset"][config["dataset"]]: _logger.info(f"Plotting distributions for {sample}") - make_plots(outdir, sample, args.dataset) + make_plots(outdir, sample, config["dataset"]) if args.export_onnx: try: @@ -223,12 +241,18 @@ def run(rank, world_size, args, outdir, logfile): dist.destroy_process_group() -def main(): - args = parser.parse_args() - world_size = len(args.gpus.split(",")) # will be 1 for both cpu ("") and single-gpu ("0") +def override_config(config, args): + """override config with values from argparse Namespace""" + for arg in vars(args): + arg_value = getattr(args, arg) + if arg_value is not None: + config[arg] = arg_value + return config + + +def device_agnostic_run(config, args, world_size, outdir): if args.train: # create a new outdir when training a model to never overwrite - outdir = create_experiment_dir(prefix=args.prefix + Path(args.config).stem + "_") logfile = f"{outdir}/train.log" _configLogger("mlpf", filename=logfile) @@ -240,7 +264,7 @@ def main(): os.system(f"cp {args.config} {outdir}/test-config.yaml") - if args.gpus: + if config["gpus"]: assert ( world_size <= torch.cuda.device_count() ), f"--gpus is too high (specified {world_size} gpus but only {torch.cuda.device_count()} gpus are available)" @@ -253,19 +277,106 @@ def main(): mp.spawn( run, - args=(world_size, args, outdir, logfile), + args=(world_size, config, args, outdir, logfile), nprocs=world_size, join=True, ) elif world_size == 1: rank = 0 _logger.info(f"Will use single-gpu: {torch.cuda.get_device_name(rank)}", color="purple") - run(rank, world_size, args, outdir, logfile) + run(rank, world_size, config, args, outdir, logfile) else: rank = "cpu" _logger.info("Will use cpu", color="purple") - run(rank, world_size, args, outdir, logfile) + run(rank, world_size, config, args, outdir, logfile) + + +def main(): + args = parser.parse_args() + world_size = len(args.gpus.split(",")) # will be 1 for both cpu ("") and single-gpu ("0") + + with open(args.config, "r") as stream: # load config (includes: which physics samples, model params) + config = yaml.safe_load(stream) + + # override loaded config with values
from command line args config = override_config(config, args) + if args.hpo: import ray from ray import tune from ray import train as ray_train # from ray.tune.logger import TBXLoggerCallback from raytune.pt_search_space import raytune_num_samples, search_space, set_hps_from_search_space from raytune.utils import get_raytune_schedule, get_raytune_search_alg name = args.hpo # name of Ray Tune experiment directory os.environ["TUNE_DISABLE_STRICT_METRIC_CHECKING"] = "1" # don't crash if a metric is missing if isinstance(config["raytune"]["local_dir"], type(None)): raise TypeError("Please specify a local_dir in the raytune section of the config file.") trd = config["raytune"]["local_dir"] + "/tune_result_dir" os.environ["TUNE_RESULT_DIR"] = trd expdir = Path(config["raytune"]["local_dir"]) / name expdir.mkdir(parents=True, exist_ok=True) shutil.copy( "mlpf/raytune/search_space.py", str(Path(config["raytune"]["local_dir"]) / name / "search_space.py"), ) # Copy the search space file to the train dir for later reference shutil.copy( args.config, str(Path(config["raytune"]["local_dir"]) / name / "config.yaml"), ) # Copy the config file to the train dir for later reference if not args.local: ray.init(address="auto") sched = get_raytune_schedule(config["raytune"]) search_alg = get_raytune_search_alg(config["raytune"]) def hpo(search_space, config, args, world_size): config = set_hps_from_search_space(search_space, config) outdir = ray_train.get_context().get_trial_dir() device_agnostic_run(config, args, world_size, outdir) start = datetime.now() analysis = tune.run( partial( hpo, config=config, args=args, world_size=world_size, ), config=search_space, resources_per_trial={"cpu": args.ray_cpus, "gpu": args.ray_gpus}, name=name, scheduler=sched, search_alg=search_alg, num_samples=raytune_num_samples, local_dir=config["raytune"]["local_dir"], # callbacks=[TBXLoggerCallback()], log_to_file=True, resume=False, # TODO: make this configurable max_failures=2, # sync_config=sync_config, ) end = datetime.now() logging.info("Total time of tune.run(...): {}".format(end - start)) logging.info( "Best hyperparameters found according to {} were: {}".format( config["raytune"]["default_metric"], analysis.get_best_config( metric=config["raytune"]["default_metric"], mode=config["raytune"]["default_mode"], scope="all", ), ) ) + else: outdir = create_experiment_dir(prefix=(args.prefix or "") + Path(args.config).stem + "_") device_agnostic_run(config, args, world_size, outdir) if __name__ == "__main__": diff --git a/mlpf/raytune/pt_search_space.py b/mlpf/raytune/pt_search_space.py new file mode 100644 index 000000000..e8076c491 --- /dev/null +++ b/mlpf/raytune/pt_search_space.py @@ -0,0 +1,48 @@ +from ray.tune import choice # grid_search, choice, loguniform, quniform raytune_num_samples = 8 # Number of random samples to draw from search space. Set to 1 for grid search.
+samp = choice + +# gnn scan +search_space = { + # optimizer parameters + "lr": samp([1e-4, 1e-3, 1e-2]), + # "gpu_batch_multiplier": samp([10, 20, 40]), + # model arch parameters + "conv_type": samp(["gnn_lsh"]), + "embedding_dim": samp([128, 252, 512]), + # "width": samp([512]), + # "num_convs": samp([3]), + # "dropout": samp([0.0]), + # "patience": samp([20]) +} + + +def set_hps_from_search_space(search_space, config): + if "lr" in search_space.keys(): + config["lr"] = search_space["lr"] + + if "gpu_batch_multiplier" in search_space.keys(): + config["gpu_batch_multiplier"] = search_space["gpu_batch_multiplier"] + + if "conv_type" in search_space.keys(): + conv_type = search_space["conv_type"] + config["conv_type"] = conv_type + + if conv_type == "gnn_lsh" or conv_type == "transformer": + if "embedding_dim" in search_space.keys(): + config["model"][conv_type]["embedding_dim"] = search_space["embedding_dim"] + + if "width" in search_space.keys(): + config["model"][conv_type]["width"] = search_space["width"] + + if "num_convs" in search_space.keys(): + config["model"][conv_type]["num_convs"] = search_space["num_convs"] + + if "embedding_dim" in search_space.keys(): + config["embedding_dim"] = search_space["embedding_dim"] + + return config diff --git a/mlpf/raytune/utils.py b/mlpf/raytune/utils.py index b71978676..8fc44443e 100644 --- a/mlpf/raytune/utils.py +++ b/mlpf/raytune/utils.py @@ -5,13 +5,13 @@ PopulationBasedTraining, ) from ray.tune.schedulers.pb2 import PB2 # Population Based Bandits -from ray.tune.suggest.bayesopt import BayesOptSearch -from ray.tune.suggest.bohb import TuneBOHB -from ray.tune.suggest.hyperopt import HyperOptSearch -from ray.tune.suggest.nevergrad import NevergradSearch -from ray.tune.suggest.skopt import SkOptSearch +from ray.tune.search.bayesopt import BayesOptSearch +from ray.tune.search.bohb import TuneBOHB +from ray.tune.search.hyperopt import HyperOptSearch +from ray.tune.search.nevergrad import NevergradSearch +from ray.tune.search.skopt import SkOptSearch -# from ray.tune.suggest.hebo import HEBOSearch # HEBO is not yet supported +# from ray.tune.search.hebo import HEBOSearch # HEBO is not yet supported def get_raytune_search_alg(raytune_cfg, seeds=False): diff --git a/parameters/pyg-clic.yaml b/parameters/pyg-clic.yaml index 642a911fa..3bba81e05 100644 --- a/parameters/pyg-clic.yaml +++ b/parameters/pyg-clic.yaml @@ -1,8 +1,22 @@ backend: pytorch +dataset: clic +data_dir: /mnt/ceph/users/ewulff/tensorflow_datasets/clusters +gpus: "0" +gpu_batch_multiplier: 1 +load: +num_epochs: 2 +patience: 20 +lr: 0.0001 +conv_type: gnn_lsh +ntrain: +ntest: +num_workers: 0 +prefetch_factor: + model: - gnn-lsh: - conv_type: gnn-lsh + gnn_lsh: + conv_type: gnn_lsh embedding_dim: 512 width: 512 num_convs: 3 @@ -25,6 +39,26 @@ model: num_convs: 2 dropout: 0.0 +raytune: + local_dir: /mnt/ceph/users/ewulff/ray_results/ # Note: please specify an absolute path + sched: # asha, hyperband + search_alg: # bayes, bohb, hyperopt, nevergrad, scikit + default_metric: "val_loss" + default_mode: "min" + # Tune schedule specific parameters + asha: + max_t: 200 + reduction_factor: 4 + brackets: 1 + grace_period: 10 + hyperband: + max_t: 200 + reduction_factor: 4 + hyperopt: + n_random_steps: 10 + nevergrad: + n_random_steps: 10 + train_dataset: clic: clic_edm_qq_pf: diff --git a/parameters/pyg-cms-small.yaml b/parameters/pyg-cms-small.yaml index 56e8773af..6d2b9b4ee 100644
--- a/parameters/pyg-cms-small.yaml +++ b/parameters/pyg-cms-small.yaml @@ -1,8 +1,22 @@ backend: pytorch +dataset: cms +data_dir: +gpus: "0" +gpu_batch_multiplier: 1 +load: +num_epochs: 2 +patience: 20 +lr: 0.0001 +conv_type: gnn_lsh +ntrain: +ntest: +num_workers: 0 +prefetch_factor: + model: - gnn-lsh: - conv_type: gnn-lsh + gnn_lsh: + conv_type: gnn_lsh embedding_dim: 512 width: 512 num_convs: 3 diff --git a/parameters/pyg-cms-test-qcdhighpt.yaml b/parameters/pyg-cms-test-qcdhighpt.yaml index dad102a0b..4c5535f54 100644 --- a/parameters/pyg-cms-test-qcdhighpt.yaml +++ b/parameters/pyg-cms-test-qcdhighpt.yaml @@ -1,8 +1,22 @@ backend: pytorch +dataset: cms +data_dir: +gpus: "0" +gpu_batch_multiplier: 1 +load: +num_epochs: 2 +patience: 20 +lr: 0.0001 +conv_type: gnn_lsh +ntrain: +ntest: +num_workers: 0 +prefetch_factor: + model: - gnn-lsh: - conv_type: gnn-lsh + gnn_lsh: + conv_type: gnn_lsh embedding_dim: 512 width: 512 num_convs: 3 diff --git a/parameters/pyg-cms.yaml b/parameters/pyg-cms.yaml index 4cc2ff30e..cff86ebcb 100644 --- a/parameters/pyg-cms.yaml +++ b/parameters/pyg-cms.yaml @@ -1,8 +1,22 @@ backend: pytorch +dataset: cms +data_dir: /mnt/ceph/users/ewulff/tensorflow_datasets/ +gpus: "0" +gpu_batch_multiplier: 1 +load: +num_epochs: 2 +patience: 20 +lr: 0.0001 +conv_type: gnn_lsh +ntrain: +ntest: +num_workers: 0 +prefetch_factor: + model: - gnn-lsh: - conv_type: gnn-lsh + gnn_lsh: + conv_type: gnn_lsh embedding_dim: 512 width: 512 num_convs: 3 @@ -25,6 +39,26 @@ model: num_convs: 2 dropout: 0.0 +raytune: + local_dir: /mnt/ceph/users/ewulff/ray_results/ # Note: please specify an absolute path + sched: # asha, hyperband + search_alg: # bayes, bohb, hyperopt, nevergrad, scikit + default_metric: "val_loss" + default_mode: "min" + # Tune schedule specific parameters + asha: + max_t: 200 + reduction_factor: 4 + brackets: 1 + grace_period: 10 + hyperband: + max_t: 200 + reduction_factor: 4 + hyperopt: + n_random_steps: 10 + nevergrad: + n_random_steps: 10 + train_dataset: cms: cms_pf_ttbar: diff --git a/parameters/pyg-delphes.yaml b/parameters/pyg-delphes.yaml index 42b09495b..74483481a 100644 --- a/parameters/pyg-delphes.yaml +++ b/parameters/pyg-delphes.yaml @@ -1,8 +1,22 @@ backend: pytorch +dataset: delphes +data_dir: +gpus: "0" +gpu_batch_multiplier: 1 +load: +num_epochs: 2 +patience: 20 +lr: 0.0001 +conv_type: gnn_lsh +ntrain: +ntest: +num_workers: 0 +prefetch_factor: + model: - gnn-lsh: - conv_type: gnn-lsh + gnn_lsh: + conv_type: gnn_lsh embedding_dim: 512 width: 512 num_convs: 3 diff --git a/parameters/pyg-workflow-test.yaml b/parameters/pyg-workflow-test.yaml index eafc3cb3d..1ffc6a354 100644 --- a/parameters/pyg-workflow-test.yaml +++ b/parameters/pyg-workflow-test.yaml @@ -1,8 +1,22 @@ backend: pytorch +dataset: cms +data_dir: +gpus: +gpu_batch_multiplier: 1 +load: +num_epochs: 2 +patience: 20 +lr: 0.0001 +conv_type: gravnet +ntrain: +ntest: +num_workers: 0 +prefetch_factor: + model: - gnn-lsh: - conv_type: gnn-lsh + gnn_lsh: + conv_type: gnn_lsh embedding_dim: 512 width: 512 num_convs: 3 diff --git a/scripts/flatiron/pt_raytune_1GPUperTrial.sh b/scripts/flatiron/pt_raytune_1GPUperTrial.sh new file mode 100755 index 000000000..ebc3b567c --- /dev/null +++ b/scripts/flatiron/pt_raytune_1GPUperTrial.sh @@ -0,0 +1,88 @@ +#!/bin/bash + +#SBATCH -t 2:00:00 +#SBATCH -N 2 +#SBATCH --tasks-per-node=1 +#SBATCH -p gpu +#SBATCH --constraint=a100,ib +#SBATCH --gpus-per-task=4 +#SBATCH --cpus-per-task=64 + +# Job name +#SBATCH -J raytune + +# Output and error logs +#SBATCH -o 
logs_slurm/log_%x_%j.out +#SBATCH -e logs_slurm/log_%x_%j.err + +# Add jobscript to job output +echo "#################### Job submission script. #############################" +cat $0 +echo "################# End of job submission script. #########################" + +set -x + +module --force purge; module load modules/2.2-20230808 +module load slurm gcc cmake cuda/12.1.1 cudnn/8.9.2.26-12.x nccl openmpi apptainer + +nvidia-smi +source ~/miniconda3/bin/activate pytorch +which python3 +python3 --version + +export CUDA_VISIBLE_DEVICES=0,1,2,3 +num_gpus=4 + + +################# DO NOT CHANGE THINGS HERE UNLESS YOU KNOW WHAT YOU ARE DOING ############### +# This script is a modification to the implementation suggested by gregSchwartz18 here: +# https://github.com/ray-project/ray/issues/826#issuecomment-522116599 +redis_password=$(uuidgen) +export redis_password +echo "Redis password: ${redis_password}" + +nodes=$(scontrol show hostnames $SLURM_JOB_NODELIST) # Getting the node names +nodes_array=( $nodes ) + +node_1=${nodes_array[0]} +ip=$(srun --nodes=1 --ntasks=1 -w $node_1 hostname --ip-address) # making redis-address +port=6379 +ip_head=$ip:$port +export ip_head +echo "IP Head: $ip_head" + +echo "STARTING HEAD at $node_1" +srun --nodes=1 --ntasks=1 -w $node_1 \ + ray start --head --node-ip-address="$node_1" --port=$port \ + --num-cpus $((SLURM_CPUS_PER_TASK)) --num-gpus $num_gpus --block & # mlpf/raytune/start-head.sh $ip $port & + +sleep 10 + +worker_num=$(($SLURM_JOB_NUM_NODES - 1)) #number of nodes other than the head node +for (( i=1; i<=$worker_num; i++ )) +do + node_i=${nodes_array[$i]} + echo "STARTING WORKER $i at $node_i" + srun --nodes=1 --ntasks=1 -w $node_i \ + ray start --address "$node_1":"$port" \ + --num-cpus $((SLURM_CPUS_PER_TASK)) --num-gpus $num_gpus --block & # mlpf/raytune/start-worker.sh $ip_head & + sleep 5 +done + +echo All Ray workers started. +############################################################################################## + +#### call your code below +# python3 mlpf/pipeline.py raytune -c $1 -n $2 --cpus $((SLURM_CPUS_PER_TASK/4)) \ +# --gpus 1 --seeds --comet-exp-name particleflow-raytune + +python3 -u mlpf/pyg_pipeline.py --train \ + --config $1 \ + --hpo $2 \ + --ray-cpus $((SLURM_CPUS_PER_TASK/4)) \ + --ray-gpus 1 \ + --gpus "0" \ + --ntrain 1000 \ + --ntest 1000 + +exit diff --git a/scripts/flatiron/pt_train_4GPUs.slurm b/scripts/flatiron/pt_train_4GPUs.slurm new file mode 100644 index 000000000..93efab405 --- /dev/null +++ b/scripts/flatiron/pt_train_4GPUs.slurm @@ -0,0 +1,46 @@ +#!/bin/sh + +# Walltime limit +#SBATCH -t 1:00:00 +#SBATCH -N 1 +#SBATCH --exclusive +#SBATCH --tasks-per-node=1 +#SBATCH -p gpu +#SBATCH --gpus-per-task=4 +#SBATCH --constraint=a100-80gb,ib + +# Job name +#SBATCH -J pt_train + +# Output and error logs +#SBATCH -o logs_slurm/log_%x_%j.out +#SBATCH -e logs_slurm/log_%x_%j.err + +# Add jobscript to job output +echo "#################### Job submission script. #############################" +cat $0 +echo "################# End of job submission script. #########################" + + +module --force purge; module load modules/2.2-20230808 +module load slurm gcc cmake cuda/12.1.1 cudnn/8.9.2.26-12.x nccl openmpi apptainer + +nvidia-smi +source ~/miniconda3/bin/activate pytorch +which python3 +python3 --version + + +echo 'Starting training.'
+CUDA_VISIBLE_DEVICES=0,1,2,3 python3 -u mlpf/pyg_pipeline.py --train \ + --config $1 \ + --prefix $2 \ + --gpus "0" \ + --num-epochs 3 \ + --lr 0.0001 \ + --conv-type gnn_lsh \ + --ntrain 400 \ + --ntest 400 \ + --num-workers 1 + +echo 'Training done.' diff --git a/scripts/local_test_pyg.sh b/scripts/local_test_pyg.sh index 7ba0e80ab..40d65ba04 100755 --- a/scripts/local_test_pyg.sh +++ b/scripts/local_test_pyg.sh @@ -27,4 +27,4 @@ mkdir -p experiments tfds build mlpf/heptfds/cms_pf/ttbar --manual_dir ./local_test_data -python mlpf/pyg_pipeline.py --config parameters/pyg-workflow-test.yaml --dataset cms --data_dir ./tensorflow_datasets/ --prefix MLPF_test_ --nvalid 1 --gpus "" --train --test --make-plots +python mlpf/pyg_pipeline.py --config parameters/pyg-workflow-test.yaml --dataset cms --data-dir ./tensorflow_datasets/ --prefix MLPF_test_ --nvalid 1 --gpus "" --train --test --make-plots
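
Usage note on the new search-space module: each trial samples values from search_space in mlpf/raytune/pt_search_space.py, and set_hps_from_search_space() copies the sampled keys into the YAML config before device_agnostic_run() launches the trial. A minimal sketch of extending it follows; the extra keys and value ranges are illustrative assumptions, not defaults shipped in this patch.

# Illustrative extension of mlpf/raytune/pt_search_space.py; the ranges below are
# assumptions for this sketch, not values used in the patch.
from ray.tune import choice, loguniform

raytune_num_samples = 20  # draw 20 random configurations from the space below

search_space = {
    "lr": loguniform(1e-5, 1e-2),               # sample the learning rate on a log scale
    "gpu_batch_multiplier": choice([1, 2, 4]),  # key already mapped by set_hps_from_search_space
    "conv_type": choice(["gnn_lsh"]),
    "embedding_dim": choice([128, 252, 512]),
    "num_convs": choice([2, 3, 4]),
}
# Any brand-new key would also need a corresponding mapping added to
# set_hps_from_search_space(), which rewrites the YAML config once per trial.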
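Once an experiment under raytune.local_dir has finished, the per-trial results can also be inspected offline, independently of the pipeline. A minimal sketch, assuming Ray 2.x and an illustrative experiment path; val_loss/min match the default_metric and default_mode set in the parameter files.

# Illustrative post-hoc inspection of a finished HPO experiment; the path is an assumption.
from ray.tune import ExperimentAnalysis

analysis = ExperimentAnalysis("/mnt/ceph/users/ewulff/ray_results/my_hpo_experiment")
best_config = analysis.get_best_config(metric="val_loss", mode="min", scope="all")
best_trial = analysis.get_best_trial(metric="val_loss", mode="min", scope="all")
print("best config:", best_config)
print("best trial:", best_trial)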