pytorch backend major update (jpata#240)
* best configs

* fix jetdef

* fix val loss high values

* fix loss on cpu bottleneck

* add num-workers and prefetch factors to args
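
The last two items are input-pipeline changes. As a rough sketch (not the patch itself), exposing dataloading knobs like these through `torch.utils.data.DataLoader` looks as follows; the dataset here is a stand-in:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.randn(1024, 16))  # stand-in for the real PF dataset

loader = DataLoader(
    dataset,
    batch_size=32,
    num_workers=2,      # CPU worker processes preparing batches in parallel
    prefetch_factor=4,  # batches each worker keeps loaded ahead of the training step
)

for (x,) in loader:
    pass  # the training step would consume the prefetched batch here
```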
farakiko authored Oct 20, 2023
1 parent 9d21b93 commit 36d8584
Showing 26 changed files with 741 additions and 1,810 deletions.
49 changes: 49 additions & 0 deletions mlpf/cuda_test.py
@@ -0,0 +1,49 @@
+"""
+Simple script that tests that CUDA is available on the number of GPUs specified.
+Author: Farouk Mokhtar
+"""
+
+import argparse
+import logging
+import os
+
+os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
+
+import torch
+from pyg.logger import _logger
+
+logging.basicConfig(level=logging.INFO)
+
+parser = argparse.ArgumentParser()
+
+
+parser.add_argument("--gpus", type=str, default="0", help="to use CPU set to empty string; else e.g., `0,1`")
+
+
+def main():
+    args = parser.parse_args()
+    world_size = len(args.gpus.split(","))  # will be 1 for both cpu ("") and single-gpu ("0")
+
+    if args.gpus:
+        assert (
+            world_size <= torch.cuda.device_count()
+        ), f"--gpus is too high (specified {world_size} gpus but only {torch.cuda.device_count()} gpus are available)"
+
+        torch.cuda.empty_cache()
+        if world_size > 1:
+            _logger.info(f"Will use torch.nn.parallel.DistributedDataParallel() and {world_size} gpus", color="purple")
+            for rank in range(world_size):
+                _logger.info(torch.cuda.get_device_name(rank), color="purple")
+
+        elif world_size == 1:
+            rank = 0
+            _logger.info(f"Will use single-gpu: {torch.cuda.get_device_name(rank)}", color="purple")
+
+    else:
+        rank = "cpu"
+        _logger.info("Will use cpu", color="purple")
+
+
+if __name__ == "__main__":
+    main()
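
A quick usage sketch for the new script (flag semantics per the argparse help above):

```bash
# verify that two GPUs are visible before a multi-gpu run
python mlpf/cuda_test.py --gpus 0,1

# CPU-only check
python mlpf/cuda_test.py --gpus ""
```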
10 changes: 10 additions & 0 deletions mlpf/plotting/plot_utils.py
@@ -121,6 +121,16 @@ def get_class_names(dataset_name):
     "delphes_ttbar_pf": r"Delphes-CMS $pp \rightarrow \mathrm{t}\overline{\mathrm{t}}$",
     "delphes_qcd_pf": r"Delphes-CMS $pp \rightarrow \mathrm{QCD}$",
     "cms_pf_qcd": r"CMS QCD+PU events",
+    "cms_pf_ztt": r"CMS Ztt events",
+    "cms_pf_multi_particle_gun": r"CMS multi particle gun events",
+    "cms_pf_single_electron": r"CMS single electron particle gun events",
+    "cms_pf_single_gamma": r"CMS single photon gun events",
+    "cms_pf_single_mu": r"CMS single muon particle gun events",
+    "cms_pf_single_pi": r"CMS single pion particle gun events",
+    "cms_pf_single_pi0": r"CMS single neutral pion particle gun events",
+    "cms_pf_single_proton": r"CMS single proton particle gun events",
+    "cms_pf_single_tau": r"CMS single tau particle gun events",
+    "cms_pf_sms_t1tttt": r"CMS sms t1tttt events",
 }
14 changes: 7 additions & 7 deletions mlpf/pyg/README.md
@@ -12,23 +12,23 @@ The current pytorch backend shares the same dataset format as the tensorflow backend
 
 # Supervised training or testing
 
-First make sure to update the config yaml `../../parameters/pyg_config.yaml` to your desired model parameter configuration and choice of physics samples for training and testing.
+First make sure to update the config yaml, e.g. `../../parameters/pyg-cms-test.yaml`, to your desired model parameter configuration and choice of physics samples for training and testing.
 
-After that, the entry point to launch training or testing for either CMS, DELPHES or CLIC is the same.
+After that, the entry point to launch training or testing for CMS, DELPHES, or CLIC is the same. From the main repo, run:
 
 ```bash
-cd ../
-python -u pyg_pipeline.py --dataset=${} --data_dir=${} --model-prefix=${} --gpus=${}
+python -u mlpf/pyg_pipeline.py --dataset=${} --data_dir=${} --prefix=${} --gpus=${} --ntrain 10 --nvalid 10 --ntest 10
 ```
 where:
 - `--dataset`: choices are `cms`, `delphes`, or `clic`
 - `--data_dir`: path to the tensorflow_datasets (e.g. `../data/tensorflow_datasets/`)
-- `--model-prefix`: path pointing to the model directory that holds the results (e.g. `../experiments/MLPF_test`)
+- `--prefix`: path pointing to the model directory (note: a unique hash will be appended to avoid overwriting)
 - `--gpus`: to use the CPU, set to the empty string `""`; to use GPUs, provide e.g. `"0,1"`
+- `--ntrain`, `--nvalid`, `--ntest`: specify the number of events (per sample) to use
 
 Adding the arguments:
-- `--load` will load a pre-trained model
-- `--train` will run a training (may train a loaded model if `--load` is provided)
+- `--load` will load a pre-trained model
+- `--train` will run a training (may train a loaded model if `--load` is provided)
 - `--test` will run inference and save the predictions as `.parquet` files
 - `--make-plots` will use the predictions stored after running with `--test` to make plots for evaluation
 - `--export-onnx` will export the model to ONNX
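
Putting the flags above together, a hypothetical end-to-end invocation (paths and event counts are placeholders) could look like:

```bash
python -u mlpf/pyg_pipeline.py --dataset=clic --data_dir=tensorflow_datasets/ \
    --prefix=MLPF_test --gpus=0,1 \
    --train --test --make-plots --ntrain 1000 --nvalid 100 --ntest 100
```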
4 changes: 0 additions & 4 deletions mlpf/pyg/model.py → mlpf/pyg/gnn_lsh.py
@@ -10,7 +10,6 @@ def point_wise_feed_forward_network(
     activation="ELU",
     dropout=0.0,
 ):
-
     layers = []
     layers.append(
         nn.Linear(
@@ -160,7 +159,6 @@ def __init__(self, distance_dim=128, max_num_bins=200, bin_size=128, kernel=Node
         )
 
     def forward(self, x_msg, x_node, msk, training=False):
-
         shp = x_msg.shape
         n_points = shp[1]
 
@@ -230,7 +228,6 @@ def reverse_lsh(bins_split, points_binned_enc):
 
 class CombinedGraphLayer(nn.Module):
     def __init__(self, *args, **kwargs):
-
         self.inout_dim = kwargs.pop("inout_dim")
         self.max_num_bins = kwargs.pop("max_num_bins")
         self.bin_size = kwargs.pop("bin_size")
@@ -274,7 +271,6 @@ def __init__(self, *args, **kwargs):
         self.dropout_layer = torch.nn.Dropout(self.dropout)
 
     def forward(self, x, msk):
-
         n_elems = x.shape[1]
         bins_to_pad_to = -torch.floor_divide(-n_elems, self.bin_size)
 
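
A side note on the padding arithmetic in the last fragment: `-torch.floor_divide(-n_elems, self.bin_size)` is the usual ceiling-division idiom. A minimal pure-Python sketch of the same identity:

```python
n_elems, bin_size = 1000, 128

# ceil(n / b) computed with floor division: ceil(n / b) == -floor(-n / b)
bins_to_pad_to = -(-n_elems // bin_size)
assert bins_to_pad_to == 8  # 1000 elements are padded into 8 bins of 128
```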
177 changes: 67 additions & 110 deletions mlpf/pyg/inference.py
@@ -1,5 +1,4 @@
 import os
-import os.path as osp
 import time
 from pathlib import Path
 
@@ -24,132 +23,98 @@
 )
 
 from .logger import _logger
-from .utils import CLASS_NAMES
-
-jetdef = fastjet.JetDefinition(fastjet.ee_genkt_algorithm, 0.7, -1.0)
-jet_pt = 5.0
-jet_match_dr = 0.1
-
-
-def particle_array_to_awkward(batch_ids, arr_id, arr_p4):
-    ret = {
-        "cls_id": arr_id,
-        "pt": arr_p4[:, 1],
-        "eta": arr_p4[:, 2],
-        "sin_phi": arr_p4[:, 3],
-        "cos_phi": arr_p4[:, 4],
-        "energy": arr_p4[:, 5],
-    }
-    ret["phi"] = np.arctan2(ret["sin_phi"], ret["cos_phi"])
-    ret = awkward.from_iter([{k: ret[k][batch_ids == b] for k in ret.keys()} for b in np.unique(batch_ids)])
-    return ret
+from .utils import CLASS_NAMES, unpack_predictions, unpack_target
 
 
 @torch.no_grad()
-def run_predictions(rank, mlpf, loader, sample, outpath):
+def run_predictions(rank, model, loader, sample, outpath, jetdef, jet_ptcut=5.0, jet_match_dr=0.1):
     """Runs inference on the given sample and stores the output as .parquet files."""
 
-    if not osp.isdir(f"{outpath}/preds/{sample}"):
-        os.makedirs(f"{outpath}/preds/{sample}")
+    model.eval()
 
     ti = time.time()
-
-    for i, event in tqdm.tqdm(enumerate(loader), total=len(loader)):
-        event.X = event.X.to(rank)
-        event.batch = event.batch.to(rank)
-
-        # recall target ~ ["PDG", "charge", "pt", "eta", "sin_phi", "cos_phi", "energy", "jet_idx"]
-        target_ids = event.ygen[:, 0].long()
-        event.ygen = event.ygen[:, 1:]
-
-        cand_ids = event.ycand[:, 0].long()
-        event.ycand = event.ycand[:, 1:]
-
-        # make mlpf forward pass
-        pred_ids_one_hot, pred_momentum, pred_charge = mlpf(event)
-        pred_ids_one_hot = pred_ids_one_hot.detach().cpu()
-        pred_momentum = pred_momentum.detach().cpu()
-        pred_charge = pred_charge.detach().cpu()
-
-        pred_ids = torch.argmax(pred_ids_one_hot, axis=-1)
-        pred_charge = torch.argmax(pred_charge, axis=1, keepdim=True) - 1
-        pred_p4 = torch.cat([pred_charge, pred_momentum], axis=-1)
-
-        batch_ids = event.batch.cpu().numpy()
-        awkvals = {
-            "gen": particle_array_to_awkward(batch_ids, target_ids.cpu().numpy(), event.ygen.cpu().numpy()),
-            "cand": particle_array_to_awkward(batch_ids, cand_ids.cpu().numpy(), event.ycand.cpu().numpy()),
-            "pred": particle_array_to_awkward(batch_ids, pred_ids.cpu().numpy(), pred_p4.cpu().numpy()),
-        }
+    for i, batch in tqdm.tqdm(enumerate(loader), total=len(loader)):
+        ygen = unpack_target(batch.ygen)
+        ycand = unpack_target(batch.ycand)
+        ypred = unpack_predictions(model(batch.to(rank)))
+
+        for k, v in ypred.items():
+            ypred[k] = v.detach().cpu()
+
+        # loop over the batch to disentangle the events
+        batch_ids = batch.batch.cpu().numpy()
 
-        gen_p4, cand_p4, pred_p4 = [], [], []
-        gen_cls, cand_cls, pred_cls = [], [], []
-        Xs = []
+        jets_coll = {}
+        Xs, p4s = [], {"gen": [], "cand": [], "pred": []}
         for _ibatch in np.unique(batch_ids):
             msk_batch = batch_ids == _ibatch
-            msk_gen = (target_ids[msk_batch] != 0).numpy()
-            msk_cand = (cand_ids[msk_batch] != 0).numpy()
-            msk_pred = (pred_ids[msk_batch] != 0).numpy()
 
-            Xs.append(event.X[msk_batch].cpu().numpy())
+            Xs.append(batch.X[msk_batch].cpu().numpy())
 
-            gen_p4.append(event.ygen[msk_batch, 1:][msk_gen].numpy())
-            gen_cls.append(target_ids[msk_batch][msk_gen].numpy())
+            # mask nulls for jet reconstruction
+            msk = (ygen["cls_id"][msk_batch] != 0).numpy()
+            p4s["gen"].append(ygen["p4"][msk_batch][msk].numpy())
 
-            cand_p4.append(event.ycand[msk_batch, 1:][msk_cand].numpy())
-            cand_cls.append(cand_ids[msk_batch][msk_cand].numpy())
+            msk = (ycand["cls_id"][msk_batch] != 0).numpy()
+            p4s["cand"].append(ycand["p4"][msk_batch][msk].numpy())
 
-            pred_p4.append(pred_momentum[msk_batch, :][msk_pred].numpy())
-            pred_cls.append(pred_ids[msk_batch][msk_pred].numpy())
+            msk = (ypred["cls_id"][msk_batch] != 0).numpy()
+            p4s["pred"].append(ypred["p4"][msk_batch][msk].numpy())
 
         Xs = awkward.from_iter(Xs)
-        gen_p4 = awkward.from_iter(gen_p4)
-        gen_cls = awkward.from_iter(gen_cls)
-        gen_p4 = vector.awk(
-            awkward.zip({"pt": gen_p4[:, :, 0], "eta": gen_p4[:, :, 1], "phi": gen_p4[:, :, 2], "e": gen_p4[:, :, 3]})
-        )
 
-        cand_p4 = awkward.from_iter(cand_p4)
-        cand_cls = awkward.from_iter(cand_cls)
-        cand_p4 = vector.awk(
-            awkward.zip({"pt": cand_p4[:, :, 0], "eta": cand_p4[:, :, 1], "phi": cand_p4[:, :, 2], "e": cand_p4[:, :, 3]})
-        )
+        for typ in ["gen", "cand"]:
+            vec = vector.awk(
+                awkward.zip(
+                    {
+                        "pt": awkward.from_iter(p4s[typ])[:, :, 0],
+                        "eta": awkward.from_iter(p4s[typ])[:, :, 1],
+                        "phi": awkward.from_iter(p4s[typ])[:, :, 2],
+                        "e": awkward.from_iter(p4s[typ])[:, :, 3],
+                    }
+                )
+            )
+            cluster = fastjet.ClusterSequence(vec.to_xyzt(), jetdef)
+            jets_coll[typ] = cluster.inclusive_jets(min_pt=jet_ptcut)
 
         # in case of no predicted particles in the batch
-        if torch.sum(pred_ids != 0) == 0:
-            pt = build_dummy_array(len(pred_p4), np.float64)
-            eta = build_dummy_array(len(pred_p4), np.float64)
-            phi = build_dummy_array(len(pred_p4), np.float64)
-            pred_cls = build_dummy_array(len(pred_p4), np.float64)
-            energy = build_dummy_array(len(pred_p4), np.float64)
-            pred_p4 = vector.awk(awkward.zip({"pt": pt, "eta": eta, "phi": phi, "e": energy}))
+        if torch.sum(ypred["cls_id"] != 0) == 0:
+            vec = vector.awk(
+                awkward.zip(
+                    {
+                        "pt": build_dummy_array(len(p4s["pred"]), np.float64),
+                        "eta": build_dummy_array(len(p4s["pred"]), np.float64),
+                        "phi": build_dummy_array(len(p4s["pred"]), np.float64),
+                        "e": build_dummy_array(len(p4s["pred"]), np.float64),
+                    }
+                )
+            )
         else:
-            pred_p4 = awkward.from_iter(pred_p4)
-            pred_cls = awkward.from_iter(pred_cls)
-            pred_p4 = vector.awk(
+            vec = vector.awk(
                 awkward.zip(
                     {
-                        "pt": pred_p4[:, :, 0],
-                        "eta": pred_p4[:, :, 1],
-                        "phi": pred_p4[:, :, 2],
-                        "e": pred_p4[:, :, 3],
+                        "pt": awkward.from_iter(p4s["pred"])[:, :, 0],
+                        "eta": awkward.from_iter(p4s["pred"])[:, :, 1],
+                        "phi": awkward.from_iter(p4s["pred"])[:, :, 2],
+                        "e": awkward.from_iter(p4s["pred"])[:, :, 3],
                     }
                 )
             )
 
-        jets_coll = {}
-
-        cluster1 = fastjet.ClusterSequence(awkward.Array(gen_p4.to_xyzt()), jetdef)
-        jets_coll["gen"] = cluster1.inclusive_jets(min_pt=jet_pt)
-        cluster2 = fastjet.ClusterSequence(awkward.Array(cand_p4.to_xyzt()), jetdef)
-        jets_coll["cand"] = cluster2.inclusive_jets(min_pt=jet_pt)
-        cluster3 = fastjet.ClusterSequence(awkward.Array(pred_p4.to_xyzt()), jetdef)
-        jets_coll["pred"] = cluster3.inclusive_jets(min_pt=jet_pt)
+        cluster = fastjet.ClusterSequence(vec.to_xyzt(), jetdef)
+        jets_coll["pred"] = cluster.inclusive_jets(min_pt=jet_ptcut)
 
         gen_to_pred = match_two_jet_collections(jets_coll, "gen", "pred", jet_match_dr)
        gen_to_cand = match_two_jet_collections(jets_coll, "gen", "cand", jet_match_dr)
 
         matched_jets = awkward.Array({"gen_to_pred": gen_to_pred, "gen_to_cand": gen_to_cand})
 
+        awkvals = {
+            "gen": awkward.from_iter([{k: ygen[k][batch_ids == b] for k in ygen.keys()} for b in np.unique(batch_ids)]),
+            "cand": awkward.from_iter([{k: ycand[k][batch_ids == b] for k in ycand.keys()} for b in np.unique(batch_ids)]),
+            "pred": awkward.from_iter([{k: ypred[k][batch_ids == b] for k in ypred.keys()} for b in np.unique(batch_ids)]),
+        }
+
         awkward.to_parquet(
             awkward.Array(
                 {
@@ -163,9 +128,6 @@ def run_predictions(rank, mlpf, loader, sample, outpath):
         )
         _logger.info(f"Saved predictions at {outpath}/preds/{sample}/pred_{rank}_{i}.parquet")
 
-        if i == 100:
-            break
-
     _logger.info(f"Time taken to make predictions on device {rank} is: {((time.time() - ti) / 60):.2f} min")


@@ -174,25 +136,20 @@ def make_plots(outpath, sample, dataset):
 
     mplhep.set_style(mplhep.styles.CMS)
 
-    class_names = CLASS_NAMES[dataset]
-
-    _title = format_dataset_name(sample)  # use the dataset names from the common nomenclature
-
-    if not os.path.isdir(f"{outpath}/plots/"):
-        os.makedirs(f"{outpath}/plots/")
+    os.system(f"mkdir -p {outpath}/plots/{sample}")
 
-    plots_path = Path(f"{outpath}/plots/")
+    plots_path = Path(f"{outpath}/plots/{sample}/")
     pred_path = Path(f"{outpath}/preds/{sample}/")
 
     yvals, X, _ = load_eval_data(str(pred_path / "*.parquet"), -1)
 
-    plot_num_elements(X, cp_dir=plots_path, title=_title)
-    plot_sum_energy(yvals, class_names, cp_dir=plots_path, title=_title)
+    plot_num_elements(X, cp_dir=plots_path, title=format_dataset_name(sample))
+    plot_sum_energy(yvals, CLASS_NAMES[dataset], cp_dir=plots_path, title=format_dataset_name(sample))
 
-    plot_jet_ratio(yvals, cp_dir=plots_path, title=_title)
+    plot_jet_ratio(yvals, cp_dir=plots_path, title=format_dataset_name(sample))
 
     met_data = compute_met_and_ratio(yvals)
-    plot_met(met_data, cp_dir=plots_path, title=_title)
-    plot_met_ratio(met_data, cp_dir=plots_path, title=_title)
+    plot_met(met_data, cp_dir=plots_path, title=format_dataset_name(sample))
+    plot_met_ratio(met_data, cp_dir=plots_path, title=format_dataset_name(sample))
 
-    plot_particles(yvals, cp_dir=plots_path, title=_title)
+    plot_particles(yvals, cp_dir=plots_path, title=format_dataset_name(sample))
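
Finally, a hypothetical call of the updated `make_plots` (the output directory and sample name are placeholders; `cms_pf_qcd` appears in the class-name map above):

```python
# plots are written under {outpath}/plots/{sample}/ per the change above
make_plots(outpath="experiments/MLPF_test", sample="cms_pf_qcd", dataset="cms")
```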