diff --git a/mlpf/pyg/inference.py b/mlpf/pyg/inference.py
index effd79a2c..86bfe881f 100644
--- a/mlpf/pyg/inference.py
+++ b/mlpf/pyg/inference.py
@@ -24,6 +24,8 @@
 from .logger import _logger
 from .utils import CLASS_NAMES, unpack_predictions, unpack_target
 
+import torch_geometric
+from torch_geometric.data import Batch
 
 
 @torch.no_grad()
@@ -34,9 +36,18 @@ def run_predictions(rank, model, loader, sample, outpath, jetdef, jet_ptcut=5.0,
     ti = time.time()
 
     for i, batch in tqdm.tqdm(enumerate(loader), total=len(loader)):
+
+        if model.conv_type != "gravnet":
+            X_pad, mask = torch_geometric.utils.to_dense_batch(batch.X, batch.batch)
+            batch_pad = Batch(X=X_pad, mask=mask)
+            ypred = model(batch_pad.to(rank))
+            ypred = ypred[0][mask], ypred[1][mask], ypred[2][mask]
+        else:
+            ypred = model(batch.to(rank))
+
         ygen = unpack_target(batch.ygen)
         ycand = unpack_target(batch.ycand)
-        ypred = unpack_predictions(model(batch.to(rank)))
+        ypred = unpack_predictions(ypred)
 
         for k, v in ypred.items():
             ypred[k] = v.detach().cpu()
diff --git a/mlpf/pyg/mlpf.py b/mlpf/pyg/mlpf.py
index caefbec0e..1d5de74f5 100644
--- a/mlpf/pyg/mlpf.py
+++ b/mlpf/pyg/mlpf.py
@@ -74,10 +74,11 @@ def __init__(
         self.input_dim = input_dim
         self.num_convs = num_convs
 
+        self.bin_size = 640
+
         # embedding of the inputs
         if num_convs != 0:
             self.nn0 = ffn(input_dim, embedding_dim, width, self.act, dropout)
-            self.bin_size = 640
             if self.conv_type == "gravnet":
                 self.conv_id = nn.ModuleList()
                 self.conv_reg = nn.ModuleList()
diff --git a/mlpf/pyg/training.py b/mlpf/pyg/training.py
index 7827cf68a..65a447968 100644
--- a/mlpf/pyg/training.py
+++ b/mlpf/pyg/training.py
@@ -136,6 +136,7 @@ def train_and_valid(rank, world_size, model, optimizer, data_loader, is_train):
     Performs training over a given epoch. Will run a validation step every N_STEPS and after the last training batch.
     """
+    train_or_valid = "train" if is_train else "valid"
     _logger.info(f"Initiating a train={is_train} run on device rank={rank}", color="red")
 
     # this one will keep accumulating `train_loss` and then return the average
@@ -146,7 +147,13 @@ def train_and_valid(rank, world_size, model, optimizer, data_loader, is_train):
     else:
         model.eval()
 
-    for itrain, batch in tqdm.tqdm(enumerate(data_loader), total=len(data_loader)):
+    for itrain, batch in tqdm.tqdm(
+        enumerate(data_loader), total=len(data_loader), desc=f"{train_or_valid} loop on rank={rank}"
+    ):
+
+        if world_size > 1:
+            _logger.info(f"Step {itrain} on rank={rank}")
+
         batch = batch.to(rank, non_blocking=True)
 
         ygen = unpack_target(batch.ygen)
@@ -154,10 +161,7 @@
         if is_train:
             ypred = model(batch)
         else:
-            if world_size > 1:  # validation is only run on a single machine
-                ypred = model.module(batch)
-            else:
-                ypred = model(batch)
+            ypred = model(batch)
 
         ypred = unpack_predictions(ypred)
 
@@ -176,11 +180,19 @@
         for loss_ in epoch_loss:
             epoch_loss[loss_] += loss[loss_].detach()
 
+    num_data = torch.tensor(len(data_loader), device=rank)
+    # sum up the number of steps from all workers
     if world_size > 1:
-        dist.barrier()
+        torch.distributed.all_reduce(num_data)
 
     for loss_ in epoch_loss:
-        epoch_loss[loss_] = epoch_loss[loss_].cpu().item() / len(data_loader)
+        # sum up the losses from all workers
+        if world_size > 1:
+            torch.distributed.all_reduce(epoch_loss[loss_])
+        epoch_loss[loss_] = epoch_loss[loss_].cpu().item() / num_data.cpu().item()
+
+    if world_size > 1:
+        dist.barrier()
 
     return epoch_loss
 
@@ -201,7 +213,7 @@ def train_mlpf(rank, world_size, model, optimizer, train_loader, valid_loader, n
     if (rank == 0) or (rank == "cpu"):
        tensorboard_writer = SummaryWriter(f"{outdir}/runs/")
     else:
-        tensorboard_writer = False
+        tensorboard_writer = None
 
     t0_initial = time.time()
 
@@ -230,7 +242,8 @@
         start_epoch = torch.load(checkpoint_dir / "extra_state.pt")["epoch"] + 1
 
     for epoch in range(start_epoch, num_epochs):
-        _logger.info(f"Initiating epoch # {epoch}", color="bold")
+        if (rank == 0) or (rank == "cpu"):
+            _logger.info(f"Initiating epoch # {epoch}", color="bold")
         t0 = time.time()
 
         # training step
@@ -244,26 +257,26 @@
         else:
             losses_t = train_and_valid(rank, world_size, model, optimizer, train_loader, True)
 
+        losses_v = train_and_valid(rank, world_size, model, optimizer, valid_loader, False)
+
         if (rank == 0) or (rank == "cpu"):
-            losses_v = train_and_valid(rank, world_size, model, optimizer, valid_loader, False)
             if losses_v["Total"] < best_val_loss:
                 best_val_loss = losses_v["Total"]
                 stale_epochs = 0
+                if isinstance(model, torch.nn.parallel.DistributedDataParallel):
+                    model_state_dict = model.module.state_dict()
+                else:
+                    model_state_dict = model.state_dict()
+
+                torch.save(
+                    {"model_state_dict": model_state_dict, "optimizer_state_dict": optimizer.state_dict()},
+                    # "{outdir}/weights-{epoch:02d}-{val_loss:.6f}.pth".format(
+                    #     outdir=outdir, epoch=epoch+1, val_loss=losses_v["Total"]),
+                    f"{outdir}/best_weights.pth",
+                )
             else:
                 stale_epochs += 1
 
-            if isinstance(model, torch.nn.parallel.DistributedDataParallel):
-                model_state_dict = model.module.state_dict()
-            else:
-                model_state_dict = model.state_dict()
-
-            torch.save(
-                {"model_state_dict": model_state_dict, "optimizer_state_dict": optimizer.state_dict()},
-                # "{outdir}/weights-{epoch:02d}-{val_loss:.6f}.pth".format(
-                #     outdir=outdir, epoch=epoch+1, val_loss=losses_v["Total"]),
-                f"{outdir}/best_weights.pth",
-            )
-
         if hpo:
             # save model, optimizer and epoch number for HPO-supported checkpointing
             if (rank == 0) or (rank == "cpu"):
@@ -287,9 +300,6 @@
             if stale_epochs > patience:
                 break
 
-        for k, v in losses_t.items():
-            tensorboard_writer.add_scalar(f"epoch/train_loss_rank_{rank}_" + k, v, epoch)
-
         if (rank == 0) or (rank == "cpu"):
             for k, v in losses_t.items():
                 tensorboard_writer.add_scalar("epoch/train_loss_" + k, v, epoch)
@@ -297,6 +307,7 @@
             for loss in losses_of_interest:
                 losses["train"][loss].append(losses_t[loss])
                 losses["valid"][loss].append(losses_v[loss])
+
             for k, v in losses_v.items():
                 tensorboard_writer.add_scalar("epoch/valid_loss_" + k, v, epoch)
 
@@ -340,6 +351,7 @@
             with open(f"{outdir}/mlpf_losses.pkl", "wb") as f:
                 pkl.dump(losses, f)
 
-            tensorboard_writer.flush()
+            if tensorboard_writer:
+                tensorboard_writer.flush()
 
     _logger.info(f"Done with training. Total training time on device {rank} is {round((time.time() - t0_initial)/60,3)}min")
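For context (not part of the patch): since validation now runs on every rank rather than only on rank 0, each worker's summed losses and step counts are combined with `all_reduce` before averaging, which is what the `num_data` / `torch.distributed.all_reduce` changes above implement and what the per-rank validation loaders in `mlpf/pyg_pipeline.py` below rely on. A minimal sketch of the pattern, assuming a `torch.distributed` process group is initialized whenever `world_size > 1` (the helper name `average_loss` is illustrative, not from the repository):

```python
# Sketch only, not part of the patch: cross-worker loss averaging via all_reduce.
import torch
import torch.distributed as dist


def average_loss(loss_sum: torch.Tensor, num_steps: torch.Tensor) -> float:
    # each rank passes its local loss sum and its local number of steps
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(loss_sum)  # default reduction op is SUM
        dist.all_reduce(num_steps)
    return (loss_sum / num_steps).item()
```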
diff --git a/mlpf/pyg_pipeline.py b/mlpf/pyg_pipeline.py
index 317566d78..6461b954c 100644
--- a/mlpf/pyg_pipeline.py
+++ b/mlpf/pyg_pipeline.py
@@ -86,7 +86,7 @@ def run(rank, world_size, config, args, outdir, logfile):
             model_kwargs = pkl.load(f)
 
         model = MLPF(**model_kwargs)
-        optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr)
+        optimizer = torch.optim.AdamW(model.parameters(), lr=config["lr"])
 
         checkpoint = torch.load(f"{outdir}/best_weights.pth", map_location=torch.device(rank))
@@ -151,38 +151,33 @@
         train_loader = InterleavedIterator(train_loaders)
 
-        if (rank == 0) or (rank == "cpu"):  # quick validation only on a single machine
-            valid_loaders = []
-            for sample in config["valid_dataset"][config["dataset"]]:
-                version = config["valid_dataset"][config["dataset"]][sample]["version"]
-                batch_size = (
-                    config["valid_dataset"][config["dataset"]][sample]["batch_size"] * config["gpu_batch_multiplier"]
-                )
+        valid_loaders = []
+        for sample in config["valid_dataset"][config["dataset"]]:
+            version = config["valid_dataset"][config["dataset"]][sample]["version"]
+            batch_size = config["valid_dataset"][config["dataset"]][sample]["batch_size"] * config["gpu_batch_multiplier"]
 
-                ds = PFDataset(
-                    config["data_dir"],
-                    f"{sample}:{version}",
-                    "test",
-                    ["X", "ygen"],
-                    pad_3d=pad_3d,
-                    num_samples=config["nvalid"],
-                )
-                _logger.info(f"valid_dataset: {ds}, {len(ds)}", color="blue")
+            ds = PFDataset(
+                config["data_dir"],
+                f"{sample}:{version}",
+                "test",
+                ["X", "ygen"],
+                pad_3d=pad_3d,
+                num_samples=config["nvalid"],
+            )
+            _logger.info(f"valid_dataset: {ds}, {len(ds)}", color="blue")
 
-                valid_loaders.append(
-                    ds.get_loader(
-                        batch_size,
-                        1,
-                        rank,
-                        use_cuda=use_cuda,
-                        num_workers=config["num_workers"],
-                        prefetch_factor=config["prefetch_factor"],
-                    )
+            valid_loaders.append(
+                ds.get_loader(
+                    batch_size,
+                    world_size,
+                    rank,
+                    use_cuda=use_cuda,
+                    num_workers=config["num_workers"],
+                    prefetch_factor=config["prefetch_factor"],
                 )
+            )
 
-            valid_loader = InterleavedIterator(valid_loaders)
-        else:
-            valid_loader = None
+        valid_loader = InterleavedIterator(valid_loaders)
 
         train_mlpf(
             rank,
@@ -216,7 +211,7 @@
                 f"{sample}:{version}",
                 "test",
                 ["X", "ygen", "ycand"],
-                pad_3d=pad_3d,
+                pad_3d=False,  # in inference, use sparse dataset
                 num_samples=config["ntest"],
             )
             _logger.info(f"test_dataset: {ds}, {len(ds)}", color="blue")
diff --git a/mlpf/tfmodel/model.py b/mlpf/tfmodel/model.py
index 10e9c7ee4..7714b1c28 100644
--- a/mlpf/tfmodel/model.py
+++ b/mlpf/tfmodel/model.py
@@ -258,8 +258,9 @@ def call(self, X):
             dtype=X.dtype,
         )
 
-        tf.debugging.assert_greater_equal(X[:, :, 1], 0.0, message="pt", summarize=100)
-        tf.debugging.assert_greater_equal(X[:, :, 5], 0.0, message="energy", summarize=100)
+        if DEBUGGING:
+            tf.debugging.assert_greater_equal(X[:, :, 1], 0.0, message="pt", summarize=100)
+            tf.debugging.assert_greater_equal(X[:, :, 5], 0.0, message="energy", summarize=100)
 
         Xpt = tf.expand_dims(tf.math.log(X[:, :, 1] + 1.0), axis=-1)
         Xe = tf.expand_dims(tf.math.log(X[:, :, 5] + 1.0), axis=-1)
diff --git a/parameters/pyg-cms-small.yaml b/parameters/pyg-cms-small.yaml
index 6d2b9b4ee..bf42c5071 100644
--- a/parameters/pyg-cms-small.yaml
+++ b/parameters/pyg-cms-small.yaml
@@ -56,45 +56,6 @@ valid_dataset:
 
 test_dataset:
   cms:
-    cms_pf_ttbar:
-      version: 1.6.0
-      batch_size: 1
-    cms_pf_qcd:
-      version: 1.6.0
-      batch_size: 1
-    cms_pf_ztt:
-      version: 1.6.0
-      batch_size: 1
     cms_pf_qcd_high_pt:
       version: 1.6.0
-      batch_size: 1
-    cms_pf_sms_t1tttt:
-      version: 1.6.0
-      batch_size: 1
-    cms_pf_single_electron:
-      version: 1.6.0
-      batch_size: 20
-    cms_pf_single_gamma:
-      version: 1.6.0
-      batch_size: 20
-    cms_pf_single_pi0:
-      version: 1.6.0
-      batch_size: 20
-    cms_pf_single_neutron:
-      version: 1.6.0
-      batch_size: 20
-    cms_pf_single_pi:
-      version: 1.6.0
-      batch_size: 20
-    cms_pf_single_tau:
-      version: 1.6.0
-      batch_size: 20
-    cms_pf_single_mu:
-      version: 1.6.0
-      batch_size: 20
-    cms_pf_single_proton:
-      version: 1.6.0
-      batch_size: 20
-    cms_pf_multi_particle_gun:
-      version: 1.6.0
-      batch_size: 5
+      batch_size: 10
diff --git a/scripts/tallinn/rtx/pytorch.sh b/scripts/tallinn/rtx/pytorch.sh
index 5553ce416..d8e9027dd 100755
--- a/scripts/tallinn/rtx/pytorch.sh
+++ b/scripts/tallinn/rtx/pytorch.sh
@@ -12,4 +12,4 @@ singularity exec -B /scratch/persistent --nv \
     --env PYTHONPATH=hep_tfds \
     $IMG python3.10 mlpf/pyg_pipeline.py --dataset cms --gpus 0,1 \
     --data-dir /scratch/persistent/joosep/tensorflow_datasets --config parameters/pyg-cms-small.yaml \
-    --train --conv-type gnn_lsh --num-epochs 10 --ntrain 1000 --ntest 1000 --gpu-batch-multiplier 1 --num-workers 1 --prefetch-factor 10
+    --train --conv-type gnn_lsh --num-epochs 10 --ntrain 500 --ntest 500 --gpu-batch-multiplier 1 --num-workers 1 --prefetch-factor 10
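Closing note on the `mlpf/pyg/inference.py` change at the top of this patch (not part of the patch itself): for non-`gravnet` models the sparse event record is padded into a dense `[batch, max_elements, features]` tensor with `torch_geometric.utils.to_dense_batch`, and the returned boolean mask strips the padding off the model outputs again; this also fits with building the test dataset with `pad_3d=False`, since the padding now happens on the fly inside `run_predictions`. A minimal sketch of the pad/unpad round trip:

```python
# Sketch only: the dense-padding round trip used for the non-gravnet inference path.
import torch
from torch_geometric.utils import to_dense_batch

x = torch.randn(7, 4)                        # 7 elements with 4 features each
batch = torch.tensor([0, 0, 0, 1, 1, 1, 1])  # element-to-event assignment for 2 events
x_pad, mask = to_dense_batch(x, batch)       # x_pad: [2, 4, 4], mask: [2, 4] boolean
assert torch.allclose(x_pad[mask], x)        # masking recovers the sparse elements in order
```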