Commit
Fix pytorch inference after jpata#256 (jpata#257)
* merge

* validation on all ranks
jpata authored Oct 27, 2023
1 parent ae04258 commit a164f47
Showing 7 changed files with 82 additions and 101 deletions.
13 changes: 12 additions & 1 deletion mlpf/pyg/inference.py
@@ -24,6 +24,8 @@

from .logger import _logger
from .utils import CLASS_NAMES, unpack_predictions, unpack_target
import torch_geometric
from torch_geometric.data import Batch


@torch.no_grad()
@@ -34,9 +36,18 @@ def run_predictions(rank, model, loader, sample, outpath, jetdef, jet_ptcut=5.0,

ti = time.time()
for i, batch in tqdm.tqdm(enumerate(loader), total=len(loader)):

if model.conv_type != "gravnet":
X_pad, mask = torch_geometric.utils.to_dense_batch(batch.X, batch.batch)
batch_pad = Batch(X=X_pad, mask=mask)
ypred = model(batch_pad.to(rank))
ypred = ypred[0][mask], ypred[1][mask], ypred[2][mask]
else:
ypred = model(batch.to(rank))

ygen = unpack_target(batch.ygen)
ycand = unpack_target(batch.ycand)
ypred = unpack_predictions(model(batch.to(rank)))
ypred = unpack_predictions(ypred)

for k, v in ypred.items():
ypred[k] = v.detach().cpu()
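The inference change above pads each ragged event into a dense [events, max_particles, features] tensor for the non-gravnet models and then uses the returned mask to map the padded predictions back to flat per-particle rows. A minimal sketch of that padding round-trip with torch_geometric.utils.to_dense_batch (the feature count and event sizes below are illustrative, not taken from the repository):

```python
import torch
from torch_geometric.utils import to_dense_batch

# Toy ragged input: two events with 3 and 7 particles, 5 features each
x = torch.randn(10, 5)
batch_idx = torch.tensor([0, 0, 0, 1, 1, 1, 1, 1, 1, 1])  # event assignment per particle

# Pad to a dense [n_events, max_particles, n_features] tensor plus a validity mask
x_pad, mask = to_dense_batch(x, batch_idx)
print(x_pad.shape, mask.shape)  # torch.Size([2, 7, 5]) torch.Size([2, 7])

# Indexing a padded output with the mask restores the original flat ordering,
# which is how the padded model predictions are unpacked above (ypred[i][mask]).
assert torch.equal(x_pad[mask], x)
```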
3 changes: 2 additions & 1 deletion mlpf/pyg/mlpf.py
@@ -74,10 +74,11 @@ def __init__(
self.input_dim = input_dim
self.num_convs = num_convs

self.bin_size = 640

# embedding of the inputs
if num_convs != 0:
self.nn0 = ffn(input_dim, embedding_dim, width, self.act, dropout)
self.bin_size = 640
if self.conv_type == "gravnet":
self.conv_id = nn.ModuleList()
self.conv_reg = nn.ModuleList()
64 changes: 38 additions & 26 deletions mlpf/pyg/training.py
@@ -136,6 +136,7 @@ def train_and_valid(rank, world_size, model, optimizer, data_loader, is_train):
Performs training over a given epoch. Will run a validation step every N_STEPS and after the last training batch.
"""

train_or_valid = "train" if is_train else "valid"
_logger.info(f"Initiating a train={is_train} run on device rank={rank}", color="red")

# this one will keep accumulating `train_loss` and then return the average
@@ -146,18 +147,21 @@ def train_and_valid(rank, world_size, model, optimizer, data_loader, is_train):
else:
model.eval()

for itrain, batch in tqdm.tqdm(enumerate(data_loader), total=len(data_loader)):
for itrain, batch in tqdm.tqdm(
enumerate(data_loader), total=len(data_loader), desc=f"{train_or_valid} loop on rank={rank}"
):

if world_size > 1:
_logger.info(f"Step {itrain} on rank={rank}")

batch = batch.to(rank, non_blocking=True)

ygen = unpack_target(batch.ygen)

if is_train:
ypred = model(batch)
else:
if world_size > 1: # validation is only run on a single machine
ypred = model.module(batch)
else:
ypred = model(batch)
ypred = model(batch)

ypred = unpack_predictions(ypred)

@@ -176,11 +180,19 @@ def train_and_valid(rank, world_size, model, optimizer, data_loader, is_train):
for loss_ in epoch_loss:
epoch_loss[loss_] += loss[loss_].detach()

num_data = torch.tensor(len(data_loader), device=rank)
# sum up the number of steps from all workers
if world_size > 1:
dist.barrier()
torch.distributed.all_reduce(num_data)

for loss_ in epoch_loss:
epoch_loss[loss_] = epoch_loss[loss_].cpu().item() / len(data_loader)
# sum up the losses from all workers
if world_size > 1:
torch.distributed.all_reduce(epoch_loss[loss_])
epoch_loss[loss_] = epoch_loss[loss_].cpu().item() / num_data.cpu().item()

if world_size > 1:
dist.barrier()

return epoch_loss
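The new reduction block above sums both the per-rank loss components and the per-rank step counts with all_reduce before dividing, so every worker ends up with the same epoch-averaged losses. A standalone sketch of that pattern (it assumes torch.distributed has already been initialized when world_size > 1, and the function below is illustrative rather than part of the repository):

```python
import torch
import torch.distributed as dist

def reduce_epoch_loss(epoch_loss: dict, num_steps: int, device, world_size: int) -> dict:
    """Average per-rank accumulated loss tensors over the total step count on all ranks."""
    num_data = torch.tensor(num_steps, device=device)
    if world_size > 1:
        dist.barrier()             # make sure all workers reached the end of the epoch
        dist.all_reduce(num_data)  # SUM of step counts across workers
    averaged = {}
    for name, value in epoch_loss.items():
        if world_size > 1:
            dist.all_reduce(value)  # SUM of this loss component across workers
        averaged[name] = value.cpu().item() / num_data.cpu().item()
    return averaged
```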

@@ -201,7 +213,7 @@ def train_mlpf(rank, world_size, model, optimizer, train_loader, valid_loader, n
if (rank == 0) or (rank == "cpu"):
tensorboard_writer = SummaryWriter(f"{outdir}/runs/")
else:
tensorboard_writer = False
tensorboard_writer = None

t0_initial = time.time()

@@ -230,7 +242,8 @@
start_epoch = torch.load(checkpoint_dir / "extra_state.pt")["epoch"] + 1

for epoch in range(start_epoch, num_epochs):
_logger.info(f"Initiating epoch # {epoch}", color="bold")
if (rank == 0) or (rank == "cpu"):
_logger.info(f"Initiating epoch # {epoch}", color="bold")
t0 = time.time()

# training step
@@ -244,26 +257,26 @@ def train_mlpf(rank, world_size, model, optimizer, train_loader, valid_loader, n
else:
losses_t = train_and_valid(rank, world_size, model, optimizer, train_loader, True)

losses_v = train_and_valid(rank, world_size, model, optimizer, valid_loader, False)

if (rank == 0) or (rank == "cpu"):
losses_v = train_and_valid(rank, world_size, model, optimizer, valid_loader, False)
if losses_v["Total"] < best_val_loss:
best_val_loss = losses_v["Total"]
stale_epochs = 0
if isinstance(model, torch.nn.parallel.DistributedDataParallel):
model_state_dict = model.module.state_dict()
else:
model_state_dict = model.state_dict()

torch.save(
{"model_state_dict": model_state_dict, "optimizer_state_dict": optimizer.state_dict()},
# "{outdir}/weights-{epoch:02d}-{val_loss:.6f}.pth".format(
# outdir=outdir, epoch=epoch+1, val_loss=losses_v["Total"]),
f"{outdir}/best_weights.pth",
)
else:
stale_epochs += 1

if isinstance(model, torch.nn.parallel.DistributedDataParallel):
model_state_dict = model.module.state_dict()
else:
model_state_dict = model.state_dict()

torch.save(
{"model_state_dict": model_state_dict, "optimizer_state_dict": optimizer.state_dict()},
# "{outdir}/weights-{epoch:02d}-{val_loss:.6f}.pth".format(
# outdir=outdir, epoch=epoch+1, val_loss=losses_v["Total"]),
f"{outdir}/best_weights.pth",
)

if hpo:
# save model, optimizer and epoch number for HPO-supported checkpointing
if (rank == 0) or (rank == "cpu"):
@@ -287,16 +300,14 @@ def train_mlpf(rank, world_size, model, optimizer, train_loader, valid_loader, n
if stale_epochs > patience:
break

for k, v in losses_t.items():
tensorboard_writer.add_scalar(f"epoch/train_loss_rank_{rank}_" + k, v, epoch)

if (rank == 0) or (rank == "cpu"):
for k, v in losses_t.items():
tensorboard_writer.add_scalar("epoch/train_loss_" + k, v, epoch)

for loss in losses_of_interest:
losses["train"][loss].append(losses_t[loss])
losses["valid"][loss].append(losses_v[loss])

for k, v in losses_v.items():
tensorboard_writer.add_scalar("epoch/valid_loss_" + k, v, epoch)

@@ -340,6 +351,7 @@ def train_mlpf(rank, world_size, model, optimizer, train_loader, valid_loader, n
with open(f"{outdir}/mlpf_losses.pkl", "wb") as f:
pkl.dump(losses, f)

tensorboard_writer.flush()
if tensorboard_writer:
tensorboard_writer.flush()

_logger.info(f"Done with training. Total training time on device {rank} is {round((time.time() - t0_initial)/60,3)}min")
55 changes: 25 additions & 30 deletions mlpf/pyg_pipeline.py
@@ -86,7 +86,7 @@ def run(rank, world_size, config, args, outdir, logfile):
model_kwargs = pkl.load(f)

model = MLPF(**model_kwargs)
optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr)
optimizer = torch.optim.AdamW(model.parameters(), lr=config["lr"])

checkpoint = torch.load(f"{outdir}/best_weights.pth", map_location=torch.device(rank))

@@ -151,38 +151,33 @@ def run(rank, world_size, config, args, outdir, logfile):

train_loader = InterleavedIterator(train_loaders)

if (rank == 0) or (rank == "cpu"): # quick validation only on a single machine
valid_loaders = []
for sample in config["valid_dataset"][config["dataset"]]:
version = config["valid_dataset"][config["dataset"]][sample]["version"]
batch_size = (
config["valid_dataset"][config["dataset"]][sample]["batch_size"] * config["gpu_batch_multiplier"]
)
valid_loaders = []
for sample in config["valid_dataset"][config["dataset"]]:
version = config["valid_dataset"][config["dataset"]][sample]["version"]
batch_size = config["valid_dataset"][config["dataset"]][sample]["batch_size"] * config["gpu_batch_multiplier"]

ds = PFDataset(
config["data_dir"],
f"{sample}:{version}",
"test",
["X", "ygen"],
pad_3d=pad_3d,
num_samples=config["nvalid"],
)
_logger.info(f"valid_dataset: {ds}, {len(ds)}", color="blue")
ds = PFDataset(
config["data_dir"],
f"{sample}:{version}",
"test",
["X", "ygen"],
pad_3d=pad_3d,
num_samples=config["nvalid"],
)
_logger.info(f"valid_dataset: {ds}, {len(ds)}", color="blue")

valid_loaders.append(
ds.get_loader(
batch_size,
1,
rank,
use_cuda=use_cuda,
num_workers=config["num_workers"],
prefetch_factor=config["prefetch_factor"],
)
valid_loaders.append(
ds.get_loader(
batch_size,
world_size,
rank,
use_cuda=use_cuda,
num_workers=config["num_workers"],
prefetch_factor=config["prefetch_factor"],
)
)

valid_loader = InterleavedIterator(valid_loaders)
else:
valid_loader = None
valid_loader = InterleavedIterator(valid_loaders)

train_mlpf(
rank,
@@ -216,7 +211,7 @@ def run(rank, world_size, config, args, outdir, logfile):
f"{sample}:{version}",
"test",
["X", "ygen", "ycand"],
pad_3d=pad_3d,
pad_3d=False, # in inference, use sparse dataset
num_samples=config["ntest"],
)
_logger.info(f"test_dataset: {ds}, {len(ds)}", color="blue")
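With this change the validation loaders are built on every rank and sharded across world_size workers, which is what allows train_and_valid to run validation on all ranks and then all_reduce the losses. A minimal sketch of per-rank sharding using a DistributedSampler (an assumed stand-in for what PFDataset.get_loader does internally; the repository's actual loader is not shown here):

```python
import torch
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

def make_rank_loader(dataset, batch_size: int, world_size: int, rank: int) -> DataLoader:
    sampler = None
    if world_size > 1:
        # Each rank iterates over a disjoint 1/world_size slice of the dataset.
        sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank, shuffle=False)
    return DataLoader(dataset, batch_size=batch_size, sampler=sampler)
```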
5 changes: 3 additions & 2 deletions mlpf/tfmodel/model.py
@@ -258,8 +258,9 @@ def call(self, X):
dtype=X.dtype,
)

tf.debugging.assert_greater_equal(X[:, :, 1], 0.0, message="pt", summarize=100)
tf.debugging.assert_greater_equal(X[:, :, 5], 0.0, message="energy", summarize=100)
if DEBUGGING:
tf.debugging.assert_greater_equal(X[:, :, 1], 0.0, message="pt", summarize=100)
tf.debugging.assert_greater_equal(X[:, :, 5], 0.0, message="energy", summarize=100)
Xpt = tf.expand_dims(tf.math.log(X[:, :, 1] + 1.0), axis=-1)
Xe = tf.expand_dims(tf.math.log(X[:, :, 5] + 1.0), axis=-1)

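The TensorFlow model change wraps the per-batch input assertions in a DEBUGGING guard so the checks only run when explicitly enabled. A short sketch of that gating (the DEBUGGING constant below stands in for whatever flag the module actually defines):

```python
import tensorflow as tf

DEBUGGING = False  # flip to True to pay the cost of per-batch input validation

def check_inputs(X: tf.Tensor) -> None:
    if DEBUGGING:
        # In the model's input layout, column 1 is pt and column 5 is energy.
        tf.debugging.assert_greater_equal(X[:, :, 1], 0.0, message="pt", summarize=100)
        tf.debugging.assert_greater_equal(X[:, :, 5], 0.0, message="energy", summarize=100)
```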
41 changes: 1 addition & 40 deletions parameters/pyg-cms-small.yaml
@@ -56,45 +56,6 @@ valid_dataset:

test_dataset:
cms:
cms_pf_ttbar:
version: 1.6.0
batch_size: 1
cms_pf_qcd:
version: 1.6.0
batch_size: 1
cms_pf_ztt:
version: 1.6.0
batch_size: 1
cms_pf_qcd_high_pt:
version: 1.6.0
batch_size: 1
cms_pf_sms_t1tttt:
version: 1.6.0
batch_size: 1
cms_pf_single_electron:
version: 1.6.0
batch_size: 20
cms_pf_single_gamma:
version: 1.6.0
batch_size: 20
cms_pf_single_pi0:
version: 1.6.0
batch_size: 20
cms_pf_single_neutron:
version: 1.6.0
batch_size: 20
cms_pf_single_pi:
version: 1.6.0
batch_size: 20
cms_pf_single_tau:
version: 1.6.0
batch_size: 20
cms_pf_single_mu:
version: 1.6.0
batch_size: 20
cms_pf_single_proton:
version: 1.6.0
batch_size: 20
cms_pf_multi_particle_gun:
version: 1.6.0
batch_size: 5
batch_size: 10
2 changes: 1 addition & 1 deletion scripts/tallinn/rtx/pytorch.sh
@@ -12,4 +12,4 @@ singularity exec -B /scratch/persistent --nv \
--env PYTHONPATH=hep_tfds \
$IMG python3.10 mlpf/pyg_pipeline.py --dataset cms --gpus 0,1 \
--data-dir /scratch/persistent/joosep/tensorflow_datasets --config parameters/pyg-cms-small.yaml \
--train --conv-type gnn_lsh --num-epochs 10 --ntrain 1000 --ntest 1000 --gpu-batch-multiplier 1 --num-workers 1 --prefetch-factor 10
--train --conv-type gnn_lsh --num-epochs 10 --ntrain 500 --ntest 500 --gpu-batch-multiplier 1 --num-workers 1 --prefetch-factor 10
