diff --git a/mlpf/cuda_test.py b/mlpf/cuda_test.py
new file mode 100644
index 000000000..4ed35d811
--- /dev/null
+++ b/mlpf/cuda_test.py
@@ -0,0 +1,49 @@
+"""
+Simple script that tests whether CUDA is available for the number of gpus specified.
+
+Author: Farouk Mokhtar
+"""
+
+import argparse
+import logging
+import os
+
+os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
+
+import torch
+from pyg.logger import _logger
+
+logging.basicConfig(level=logging.INFO)
+
+parser = argparse.ArgumentParser()
+
+
+parser.add_argument("--gpus", type=str, default="0", help="to use CPU set to empty string; else e.g., `0,1`")
+
+
+def main():
+    args = parser.parse_args()
+    world_size = len(args.gpus.split(","))  # will be 1 for both cpu ("") and single-gpu ("0")
+
+    if args.gpus:
+        assert (
+            world_size <= torch.cuda.device_count()
+        ), f"--gpus is too high (specified {world_size} gpus but only {torch.cuda.device_count()} gpus are available)"
+
+        torch.cuda.empty_cache()
+        if world_size > 1:
+            _logger.info(f"Will use torch.nn.parallel.DistributedDataParallel() and {world_size} gpus", color="purple")
+            for rank in range(world_size):
+                _logger.info(torch.cuda.get_device_name(rank), color="purple")
+
+        elif world_size == 1:
+            rank = 0
+            _logger.info(f"Will use single-gpu: {torch.cuda.get_device_name(rank)}", color="purple")
+
+    else:
+        rank = "cpu"
+        _logger.info("Will use cpu", color="purple")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/mlpf/plotting/plot_utils.py b/mlpf/plotting/plot_utils.py
index dfa3ca70c..1b75a5c7d 100644
--- a/mlpf/plotting/plot_utils.py
+++ b/mlpf/plotting/plot_utils.py
@@ -121,6 +121,16 @@ def get_class_names(dataset_name):
     "delphes_ttbar_pf": r"Delphes-CMS $pp \rightarrow \mathrm{t}\overline{\mathrm{t}}$",
     "delphes_qcd_pf": r"Delphes-CMS $pp \rightarrow \mathrm{QCD}$",
     "cms_pf_qcd": r"CMS QCD+PU events",
+    "cms_pf_ztt": r"CMS Ztt events",
+    "cms_pf_multi_particle_gun": r"CMS multi particle gun events",
+    "cms_pf_single_electron": r"CMS single electron particle gun events",
+    "cms_pf_single_gamma": r"CMS single photon gun events",
+    "cms_pf_single_mu": r"CMS single muon particle gun events",
+    "cms_pf_single_pi": r"CMS single pion particle gun events",
+    "cms_pf_single_pi0": r"CMS single neutral pion particle gun events",
+    "cms_pf_single_proton": r"CMS single proton particle gun events",
+    "cms_pf_single_tau": r"CMS single tau particle gun events",
+    "cms_pf_sms_t1tttt": r"CMS sms t1tttt events",
 }
diff --git a/mlpf/pyg/README.md b/mlpf/pyg/README.md
index 9d55083ba..ab43b689f 100644
--- a/mlpf/pyg/README.md
+++ b/mlpf/pyg/README.md
@@ -12,23 +12,23 @@ The current pytorch backend shares the same dataset format as the tensorflow bac

 # Supervised training or testing

-First make sure to update the config yaml `../../parameters/pyg_config.yaml` to your desired model parameter configuration and choice of physics samples for training and testing.
+First make sure to update the config yaml e.g. `../../parameters/pyg-cms-test.yaml` to your desired model parameter configuration and choice of physics samples for training and testing.

-After that, the entry point to launch training or testing for either CMS, DELPHES or CLIC is the same.
+After that, the entry point to launch training or testing for either CMS, DELPHES or CLIC is the same.
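Before launching a run on several GPUs, the new `mlpf/cuda_test.py` script introduced in this patch can be used to confirm that the requested devices are actually visible to torch. A minimal sketch of the invocation, assuming the repository root as the working directory and the same `--gpus` string that will later be passed to the training entry point:

```bash
# quick sanity check: asserts that at least as many GPUs are available as requested
python -u mlpf/cuda_test.py --gpus "0,1"
```

Setting `--gpus` to an empty string makes the script fall back to CPU, mirroring the behaviour of the `--gpus` flag of the training command below.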
From the main repo run, ```bash -cd ../ -python -u pyg_pipeline.py --dataset=${} --data_dir=${} --model-prefix=${} --gpus=${} +python -u mlpf/pyg_pipeline.py --dataset=${} --data_dir=${} --prefix=${} --gpus=${} --ntrain 10 --nvalid 10 --ntest 10 ``` where: - `--dataset`: choices are `cms` or `delphes` or `clic` - `--data_dir`: path to the tensorflow_datasets (e.g. `../data/tensorflow_datasets/`) -- `--model-prefix`: path pointing to the model directory that holds the results (e.g. `../experiments/MLPF_test`) +- `--prefix`: path pointing to the model directory (note: a unique hash will be appended to avoid overwrite) - `--gpus`: to use CPU set to empty string ""; else to use gpus provide e.g. "0,1" +- `ntrain`, `nvalid`, `ntest`: specefies number of events (per sample) that will be used Adding the arguments: -- `--load` will load a pre-trained model -- `--train` will run a training (may train a loaded model if `--load` is provided) +- `--load` will load a pre-trained model +- `--train` will run a training (may train a loaded model if `--load` is provided) - `--test` will run inference and save the predictions as `.parquets` - `--make-plots` will use the predictions stored after running with `--test` to make plots for evaluation - `--export-onnx` will export the model to ONNX diff --git a/mlpf/pyg/model.py b/mlpf/pyg/gnn_lsh.py similarity index 99% rename from mlpf/pyg/model.py rename to mlpf/pyg/gnn_lsh.py index aae5b7a95..e5d4f5103 100644 --- a/mlpf/pyg/model.py +++ b/mlpf/pyg/gnn_lsh.py @@ -10,7 +10,6 @@ def point_wise_feed_forward_network( activation="ELU", dropout=0.0, ): - layers = [] layers.append( nn.Linear( @@ -160,7 +159,6 @@ def __init__(self, distance_dim=128, max_num_bins=200, bin_size=128, kernel=Node ) def forward(self, x_msg, x_node, msk, training=False): - shp = x_msg.shape n_points = shp[1] @@ -230,7 +228,6 @@ def reverse_lsh(bins_split, points_binned_enc): class CombinedGraphLayer(nn.Module): def __init__(self, *args, **kwargs): - self.inout_dim = kwargs.pop("inout_dim") self.max_num_bins = kwargs.pop("max_num_bins") self.bin_size = kwargs.pop("bin_size") @@ -274,7 +271,6 @@ def __init__(self, *args, **kwargs): self.dropout_layer = torch.nn.Dropout(self.dropout) def forward(self, x, msk): - n_elems = x.shape[1] bins_to_pad_to = -torch.floor_divide(-n_elems, self.bin_size) diff --git a/mlpf/pyg/inference.py b/mlpf/pyg/inference.py index 9524786b4..effd79a2c 100644 --- a/mlpf/pyg/inference.py +++ b/mlpf/pyg/inference.py @@ -1,5 +1,4 @@ import os -import os.path as osp import time from pathlib import Path @@ -24,132 +23,98 @@ ) from .logger import _logger -from .utils import CLASS_NAMES - -jetdef = fastjet.JetDefinition(fastjet.ee_genkt_algorithm, 0.7, -1.0) -jet_pt = 5.0 -jet_match_dr = 0.1 - - -def particle_array_to_awkward(batch_ids, arr_id, arr_p4): - ret = { - "cls_id": arr_id, - "pt": arr_p4[:, 1], - "eta": arr_p4[:, 2], - "sin_phi": arr_p4[:, 3], - "cos_phi": arr_p4[:, 4], - "energy": arr_p4[:, 5], - } - ret["phi"] = np.arctan2(ret["sin_phi"], ret["cos_phi"]) - ret = awkward.from_iter([{k: ret[k][batch_ids == b] for k in ret.keys()} for b in np.unique(batch_ids)]) - return ret +from .utils import CLASS_NAMES, unpack_predictions, unpack_target @torch.no_grad() -def run_predictions(rank, mlpf, loader, sample, outpath): +def run_predictions(rank, model, loader, sample, outpath, jetdef, jet_ptcut=5.0, jet_match_dr=0.1): """Runs inference on the given sample and stores the output as .parquet files.""" - if not osp.isdir(f"{outpath}/preds/{sample}"): - 
os.makedirs(f"{outpath}/preds/{sample}") + model.eval() ti = time.time() + for i, batch in tqdm.tqdm(enumerate(loader), total=len(loader)): + ygen = unpack_target(batch.ygen) + ycand = unpack_target(batch.ycand) + ypred = unpack_predictions(model(batch.to(rank))) - for i, event in tqdm.tqdm(enumerate(loader), total=len(loader)): - event.X = event.X.to(rank) - event.batch = event.batch.to(rank) - - # recall target ~ ["PDG", "charge", "pt", "eta", "sin_phi", "cos_phi", "energy", "jet_idx"] - target_ids = event.ygen[:, 0].long() - event.ygen = event.ygen[:, 1:] + for k, v in ypred.items(): + ypred[k] = v.detach().cpu() - cand_ids = event.ycand[:, 0].long() - event.ycand = event.ycand[:, 1:] - - # make mlpf forward pass - pred_ids_one_hot, pred_momentum, pred_charge = mlpf(event) - pred_ids_one_hot = pred_ids_one_hot.detach().cpu() - pred_momentum = pred_momentum.detach().cpu() - pred_charge = pred_charge.detach().cpu() - - pred_ids = torch.argmax(pred_ids_one_hot, axis=-1) - pred_charge = torch.argmax(pred_charge, axis=1, keepdim=True) - 1 - pred_p4 = torch.cat([pred_charge, pred_momentum], axis=-1) - - batch_ids = event.batch.cpu().numpy() - awkvals = { - "gen": particle_array_to_awkward(batch_ids, target_ids.cpu().numpy(), event.ygen.cpu().numpy()), - "cand": particle_array_to_awkward(batch_ids, cand_ids.cpu().numpy(), event.ycand.cpu().numpy()), - "pred": particle_array_to_awkward(batch_ids, pred_ids.cpu().numpy(), pred_p4.cpu().numpy()), - } + # loop over the batch to disentangle the events + batch_ids = batch.batch.cpu().numpy() - gen_p4, cand_p4, pred_p4 = [], [], [] - gen_cls, cand_cls, pred_cls = [], [], [] - Xs = [] + jets_coll = {} + Xs, p4s = [], {"gen": [], "cand": [], "pred": []} for _ibatch in np.unique(batch_ids): msk_batch = batch_ids == _ibatch - msk_gen = (target_ids[msk_batch] != 0).numpy() - msk_cand = (cand_ids[msk_batch] != 0).numpy() - msk_pred = (pred_ids[msk_batch] != 0).numpy() - Xs.append(event.X[msk_batch].cpu().numpy()) + Xs.append(batch.X[msk_batch].cpu().numpy()) - gen_p4.append(event.ygen[msk_batch, 1:][msk_gen].numpy()) - gen_cls.append(target_ids[msk_batch][msk_gen].numpy()) + # mask nulls for jet reconstruction + msk = (ygen["cls_id"][msk_batch] != 0).numpy() + p4s["gen"].append(ygen["p4"][msk_batch][msk].numpy()) - cand_p4.append(event.ycand[msk_batch, 1:][msk_cand].numpy()) - cand_cls.append(cand_ids[msk_batch][msk_cand].numpy()) + msk = (ycand["cls_id"][msk_batch] != 0).numpy() + p4s["cand"].append(ycand["p4"][msk_batch][msk].numpy()) - pred_p4.append(pred_momentum[msk_batch, :][msk_pred].numpy()) - pred_cls.append(pred_ids[msk_batch][msk_pred].numpy()) + msk = (ypred["cls_id"][msk_batch] != 0).numpy() + p4s["pred"].append(ypred["p4"][msk_batch][msk].numpy()) Xs = awkward.from_iter(Xs) - gen_p4 = awkward.from_iter(gen_p4) - gen_cls = awkward.from_iter(gen_cls) - gen_p4 = vector.awk( - awkward.zip({"pt": gen_p4[:, :, 0], "eta": gen_p4[:, :, 1], "phi": gen_p4[:, :, 2], "e": gen_p4[:, :, 3]}) - ) - cand_p4 = awkward.from_iter(cand_p4) - cand_cls = awkward.from_iter(cand_cls) - cand_p4 = vector.awk( - awkward.zip({"pt": cand_p4[:, :, 0], "eta": cand_p4[:, :, 1], "phi": cand_p4[:, :, 2], "e": cand_p4[:, :, 3]}) - ) + for typ in ["gen", "cand"]: + vec = vector.awk( + awkward.zip( + { + "pt": awkward.from_iter(p4s[typ])[:, :, 0], + "eta": awkward.from_iter(p4s[typ])[:, :, 1], + "phi": awkward.from_iter(p4s[typ])[:, :, 2], + "e": awkward.from_iter(p4s[typ])[:, :, 3], + } + ) + ) + cluster = fastjet.ClusterSequence(vec.to_xyzt(), jetdef) + jets_coll[typ] = 
cluster.inclusive_jets(min_pt=jet_ptcut) # in case of no predicted particles in the batch - if torch.sum(pred_ids != 0) == 0: - pt = build_dummy_array(len(pred_p4), np.float64) - eta = build_dummy_array(len(pred_p4), np.float64) - phi = build_dummy_array(len(pred_p4), np.float64) - pred_cls = build_dummy_array(len(pred_p4), np.float64) - energy = build_dummy_array(len(pred_p4), np.float64) - pred_p4 = vector.awk(awkward.zip({"pt": pt, "eta": eta, "phi": phi, "e": energy})) + if torch.sum(ypred["cls_id"] != 0) == 0: + vec = vector.awk( + awkward.zip( + { + "pt": build_dummy_array(len(p4s["pred"]), np.float64), + "eta": build_dummy_array(len(p4s["pred"]), np.float64), + "phi": build_dummy_array(len(p4s["pred"]), np.float64), + "e": build_dummy_array(len(p4s["pred"]), np.float64), + } + ) + ) else: - pred_p4 = awkward.from_iter(pred_p4) - pred_cls = awkward.from_iter(pred_cls) - pred_p4 = vector.awk( + vec = vector.awk( awkward.zip( { - "pt": pred_p4[:, :, 0], - "eta": pred_p4[:, :, 1], - "phi": pred_p4[:, :, 2], - "e": pred_p4[:, :, 3], + "pt": awkward.from_iter(p4s["pred"])[:, :, 0], + "eta": awkward.from_iter(p4s["pred"])[:, :, 1], + "phi": awkward.from_iter(p4s["pred"])[:, :, 2], + "e": awkward.from_iter(p4s["pred"])[:, :, 3], } ) ) - jets_coll = {} - - cluster1 = fastjet.ClusterSequence(awkward.Array(gen_p4.to_xyzt()), jetdef) - jets_coll["gen"] = cluster1.inclusive_jets(min_pt=jet_pt) - cluster2 = fastjet.ClusterSequence(awkward.Array(cand_p4.to_xyzt()), jetdef) - jets_coll["cand"] = cluster2.inclusive_jets(min_pt=jet_pt) - cluster3 = fastjet.ClusterSequence(awkward.Array(pred_p4.to_xyzt()), jetdef) - jets_coll["pred"] = cluster3.inclusive_jets(min_pt=jet_pt) + cluster = fastjet.ClusterSequence(vec.to_xyzt(), jetdef) + jets_coll["pred"] = cluster.inclusive_jets(min_pt=jet_ptcut) gen_to_pred = match_two_jet_collections(jets_coll, "gen", "pred", jet_match_dr) gen_to_cand = match_two_jet_collections(jets_coll, "gen", "cand", jet_match_dr) + matched_jets = awkward.Array({"gen_to_pred": gen_to_pred, "gen_to_cand": gen_to_cand}) + awkvals = { + "gen": awkward.from_iter([{k: ygen[k][batch_ids == b] for k in ygen.keys()} for b in np.unique(batch_ids)]), + "cand": awkward.from_iter([{k: ycand[k][batch_ids == b] for k in ycand.keys()} for b in np.unique(batch_ids)]), + "pred": awkward.from_iter([{k: ypred[k][batch_ids == b] for k in ypred.keys()} for b in np.unique(batch_ids)]), + } + awkward.to_parquet( awkward.Array( { @@ -163,9 +128,6 @@ def run_predictions(rank, mlpf, loader, sample, outpath): ) _logger.info(f"Saved predictions at {outpath}/preds/{sample}/pred_{rank}_{i}.parquet") - if i == 100: - break - _logger.info(f"Time taken to make predictions on device {rank} is: {((time.time() - ti) / 60):.2f} min") @@ -174,25 +136,20 @@ def make_plots(outpath, sample, dataset): mplhep.set_style(mplhep.styles.CMS) - class_names = CLASS_NAMES[dataset] - - _title = format_dataset_name(sample) # use the dataset names from the common nomenclature - - if not os.path.isdir(f"{outpath}/plots/"): - os.makedirs(f"{outpath}/plots/") + os.system(f"mkdir -p {outpath}/plots/{sample}") - plots_path = Path(f"{outpath}/plots/") + plots_path = Path(f"{outpath}/plots/{sample}/") pred_path = Path(f"{outpath}/preds/{sample}/") yvals, X, _ = load_eval_data(str(pred_path / "*.parquet"), -1) - plot_num_elements(X, cp_dir=plots_path, title=_title) - plot_sum_energy(yvals, class_names, cp_dir=plots_path, title=_title) + plot_num_elements(X, cp_dir=plots_path, title=format_dataset_name(sample)) + plot_sum_energy(yvals, 
CLASS_NAMES[dataset], cp_dir=plots_path, title=format_dataset_name(sample)) - plot_jet_ratio(yvals, cp_dir=plots_path, title=_title) + plot_jet_ratio(yvals, cp_dir=plots_path, title=format_dataset_name(sample)) met_data = compute_met_and_ratio(yvals) - plot_met(met_data, cp_dir=plots_path, title=_title) - plot_met_ratio(met_data, cp_dir=plots_path, title=_title) + plot_met(met_data, cp_dir=plots_path, title=format_dataset_name(sample)) + plot_met_ratio(met_data, cp_dir=plots_path, title=format_dataset_name(sample)) - plot_particles(yvals, cp_dir=plots_path, title=_title) + plot_particles(yvals, cp_dir=plots_path, title=format_dataset_name(sample)) diff --git a/mlpf/pyg/logger.py b/mlpf/pyg/logger.py index b61109530..cf0614a19 100644 --- a/mlpf/pyg/logger.py +++ b/mlpf/pyg/logger.py @@ -1,22 +1,18 @@ import logging -import os -import sys from functools import lru_cache -def _configLogger(name, stdout=sys.stdout, filename=None, loglevel=logging.INFO): +def _logging(rank, _logger, msg): + """Will log the message only on rank 0 or cpu.""" + if (rank == 0) or (rank == "cpu"): + _logger.info(msg) + + +def _configLogger(name, filename=None, loglevel=logging.INFO): # define a Handler which writes INFO messages or higher to the sys.stdout logger = logging.getLogger(name) logger.setLevel(loglevel) - if stdout: - console = logging.StreamHandler(stdout) - console.setLevel(loglevel) - console.setFormatter(logging.Formatter("[%(asctime)s] %(levelname)s: %(message)s")) - logger.addHandler(console) if filename: - dirname = os.path.dirname(filename) - if dirname and not os.path.exists(dirname): - os.makedirs(os.path.dirname(filename)) logfile = logging.FileHandler(filename) logfile.setLevel(loglevel) logfile.setFormatter(logging.Formatter("[%(asctime)s] %(levelname)s: %(message)s")) diff --git a/mlpf/pyg/mlpf.py b/mlpf/pyg/mlpf.py index 157e4373e..ffcf302f8 100644 --- a/mlpf/pyg/mlpf.py +++ b/mlpf/pyg/mlpf.py @@ -4,9 +4,7 @@ import torch_geometric.utils from torch_geometric.nn.conv import GravNetConv -from .model import CombinedGraphLayer - -# from pyg_ssl.gravnet import GravNetConv # this version also returns edge index +from .gnn_lsh import CombinedGraphLayer class GravNetLayer(nn.Module): @@ -19,8 +17,6 @@ def __init__(self, embedding_dim, space_dimensions, propagate_dimensions, k, dro self.dropout = torch.nn.Dropout(dropout) def forward(self, x, batch_index): - # possibly do something with edge index - # x_new, edge_index, edge_weight = self.conv1(x, batch_index) x_new = self.conv1(x, batch_index) x_new = self.dropout(x_new) x = self.norm1(x + x_new) @@ -161,11 +157,11 @@ def forward(self, event): # assert out_stacked.shape[0] == conv_input.shape[0] embeddings_reg.append(out_stacked) + # classification embedding_id = torch.cat([input_] + embeddings_id, axis=-1) - - # predict the PIDs preds_id = self.nn_id(embedding_id) + # regression embedding_reg = torch.cat([input_] + embeddings_reg + [preds_id], axis=-1) # do some sanity checks on the PFElement input data diff --git a/mlpf/pyg/ssl/VICReg.py b/mlpf/pyg/ssl/VICReg.py deleted file mode 100644 index bafc0776a..000000000 --- a/mlpf/pyg/ssl/VICReg.py +++ /dev/null @@ -1,121 +0,0 @@ -import torch.nn as nn -from torch_geometric.nn import global_mean_pool -from torch_geometric.nn.conv import GravNetConv - -from .utils import CLUSTERS_X, TRACKS_X, distinguish_PFelements - - -class VICReg(nn.Module): - def __init__(self, encoder, decoder): - super(VICReg, self).__init__() - self.encoder = encoder - self.decoder = decoder - - def forward(self, event): - - # 
seperate tracks from clusters - tracks, clusters = distinguish_PFelements(event) - - # encode to retrieve the representations - track_representations, cluster_representations = self.encoder(tracks, clusters) - - # decode/expand to get the embeddings - embedding_tracks, embedding_clusters = self.decoder(track_representations, cluster_representations) - - # global pooling to be able to compute a loss - pooled_tracks = global_mean_pool(embedding_tracks, tracks.batch) - pooled_clusters = global_mean_pool(embedding_clusters, clusters.batch) - - return pooled_tracks, pooled_clusters - - -class ENCODER(nn.Module): - """The Encoder part of VICReg which attempts to learns useful latent representations of tracks and clusters.""" - - def __init__( - self, - width=126, - embedding_dim=34, - num_convs=2, - space_dim=4, - propagate_dim=22, - k=8, - ): - super(ENCODER, self).__init__() - - self.act = nn.ELU - - # 1. different embedding of tracks/clusters - self.nn1 = nn.Sequential( - nn.Linear(TRACKS_X, width), - self.act(), - nn.Linear(width, width), - self.act(), - nn.Linear(width, width), - self.act(), - nn.Linear(width, embedding_dim), - ) - self.nn2 = nn.Sequential( - nn.Linear(CLUSTERS_X, width), - self.act(), - nn.Linear(width, width), - self.act(), - nn.Linear(width, width), - self.act(), - nn.Linear(width, embedding_dim), - ) - - # 2. same GNN for tracks/clusters - self.conv = nn.ModuleList() - for i in range(num_convs): - self.conv.append( - GravNetConv( - embedding_dim, - embedding_dim, - space_dimensions=space_dim, - propagate_dimensions=propagate_dim, - k=k, - ) - ) - - def forward(self, tracks, clusters): - - embedding_tracks = self.nn1(tracks.x.float()) - embedding_clusters = self.nn2(clusters.x.float()) - - # perform a series of graph convolutions - for num, conv in enumerate(self.conv): - embedding_tracks = conv(embedding_tracks, tracks.batch) - embedding_clusters = conv(embedding_clusters, clusters.batch) - - return embedding_tracks, embedding_clusters - - -class DECODER(nn.Module): - """The Decoder part of VICReg which attempts to expand the learned latent representations - of tracks and clusters into a space where a loss can be computed.""" - - def __init__( - self, - input_dim=34, - width=126, - output_dim=200, - ): - super(DECODER, self).__init__() - - self.act = nn.ELU - - # DECODER - self.expander = nn.Sequential( - nn.Linear(input_dim, width), - self.act(), - nn.Linear(width, width), - self.act(), - nn.Linear(width, width), - self.act(), - nn.Linear(width, output_dim), - ) - - def forward(self, out_tracks, out_clusters): - - return self.expander(out_tracks), self.expander(out_clusters) diff --git a/mlpf/pyg/ssl/args.py b/mlpf/pyg/ssl/args.py deleted file mode 100644 index 7b5f65c19..000000000 --- a/mlpf/pyg/ssl/args.py +++ /dev/null @@ -1,65 +0,0 @@ -import argparse - - -def parse_args(): - parser = argparse.ArgumentParser() - - parser.add_argument("--outpath", type=str, default="../experiments/", help="output folder") - parser.add_argument("--data_split_mode", type=str, default="mix", help="choices: ['quick', 'domain_adaptation', 'mix']") - - # samples to be used - parser.add_argument("--samples", default=-1, help="specifies samples to use") - - # directory containing datafiles - parser.add_argument("--dataset", type=str, default="CLIC", help="currently only CLIC is supported") - parser.add_argument("--data_path", type=str, default="../data/", help="path which contains the CLIC samples") - parser.add_argument("--n_train", type=int, default=-1, help="number of files to use for 
training") - parser.add_argument("--n_valid", type=int, default=-1, help="number of data files to use for validation") - parser.add_argument("--n_test", type=int, default=-1, help="number of data files to use for testing") - - # flag to load a pre-trained model - parser.add_argument("--load_VICReg", dest="load_VICReg", action="store_true", help="loads the model without training") - - # flag to train mlpf - parser.add_argument("--train_mlpf", dest="train_mlpf", action="store_true", help="Train MLPF") - parser.add_argument("--ssl", dest="ssl", action="store_true", help="Train ssl-based MLPF") - parser.add_argument("--native", dest="native", action="store_true", help="Train native") - - parser.add_argument("--prefix_VICReg", type=str, default=None, help="directory to hold the VICReg model") - parser.add_argument("--prefix", type=str, default="MLPF_model", help="directory to hold the mlpf model") - parser.add_argument("--overwrite", dest="overwrite", action="store_true", help="overwrites the model if True") - - # training hyperparameters - parser.add_argument("--lmbd", type=float, default=1, help="the lambda term in the VICReg loss") - parser.add_argument("--mu", type=float, default=0.1, help="the mu term in the VICReg loss") - parser.add_argument("--nu", type=float, default=1e-9, help="the nu term in the VICReg loss") - parser.add_argument("--n_epochs", type=int, default=3, help="number of training epochs for mlpf") - parser.add_argument("--n_epochs_VICReg", type=int, default=3, help="number of training epochs for VICReg") - parser.add_argument("--lr", type=float, default=5e-5, help="learning rate") - parser.add_argument("--bs", type=int, default=500, help="number of events to process at once") - parser.add_argument("--bs_VICReg", type=int, default=2000, help="number of events to process at once") - parser.add_argument("--patience", type=int, default=50, help="patience before early stopping") - - # VICReg encoder architecture - parser.add_argument("--width_encoder", type=int, default=256, help="hidden dimension of the encoder") - parser.add_argument("--embedding_dim_VICReg", type=int, default=256, help="encoded element dimension") - parser.add_argument("--num_convs_VICReg", type=int, default=3, help="number of graph convolutions") - - # VICReg decoder architecture - parser.add_argument("--width_decoder", type=int, default=256, help="hidden dimension of the decoder") - parser.add_argument("--expand_dim", type=int, default=512, help="dimension of the output of the decoder") - - # MLPF architecture - parser.add_argument("--width", type=int, default=256, help="hidden dimension of mlpf") - parser.add_argument("--embedding_dim", type=int, default=256, help="first embedding of mlpf") - parser.add_argument("--num_convs", type=int, default=3, help="number of graph layers for mlpf") - parser.add_argument("--dropout", type=float, default=0.4, help="dropout for MLPF model") - - # shared architecture - parser.add_argument("--space_dim", type=int, default=4, help="Gravnet hyperparameter") - parser.add_argument("--propagate_dim", type=int, default=22, help="Gravnet hyperparameter") - parser.add_argument("--nearest", type=int, default=32, help="k nearest neighbors") - - args = parser.parse_args() - - return args diff --git a/mlpf/pyg/ssl/evaluate.py b/mlpf/pyg/ssl/evaluate.py deleted file mode 100644 index 1185297d9..000000000 --- a/mlpf/pyg/ssl/evaluate.py +++ /dev/null @@ -1,326 +0,0 @@ -import os -import os.path as osp -import pickle as pkl -from pathlib import Path - -import awkward -import matplotlib 
-import matplotlib.pyplot as plt -import numpy as np -import sklearn -import sklearn.metrics -import torch -import torch_geometric -import tqdm - -from .utils import combine_PFelements, distinguish_PFelements - -matplotlib.use("Agg") - -# Ignore divide by 0 errors -np.seterr(divide="ignore", invalid="ignore") - -CLASS_TO_ID = { - "charged_hadron": 1, - "neutral_hadron": 2, - "photon": 3, - "electron": 4, - "muon": 5, -} -CLASS_NAMES_CLIC_LATEX = ["none", "Charged Hadron", "Neutral Hadron", r"$\gamma$", r"$e^\pm$", r"$\mu^\pm$"] - - -def particle_array_to_awkward(batch_ids, arr_id, arr_p4): - ret = { - "cls_id": arr_id, - "pt": arr_p4[:, 1], - "eta": arr_p4[:, 2], - "sin_phi": arr_p4[:, 3], - "cos_phi": arr_p4[:, 4], - "energy": arr_p4[:, 5], - } - ret["phi"] = np.arctan2(ret["sin_phi"], ret["cos_phi"]) - ret = awkward.from_iter([{k: ret[k][batch_ids == b] for k in ret.keys()} for b in np.unique(batch_ids)]) - return ret - - -def evaluate(device, encoder, mlpf, batch_size_mlpf, mode, outpath, samples): - import fastjet - import vector - from jet_utils import build_dummy_array, match_two_jet_collections - from plotting.plot_utils import load_eval_data, plot_jet_ratio - - jetdef = fastjet.JetDefinition(fastjet.ee_genkt_algorithm, 0.7, -1.0) - jet_pt = 5.0 - jet_match_dr = 0.1 - - npred_, ngen_, ncand_, = ( - {}, - {}, - {}, - ) - - mlpf.eval() - encoder.eval() - for sample, data in samples.items(): - print(f"Testing the {mode} model on the {sample}") - - this_out_path = f"{outpath}/{mode}/{sample}" - - if not osp.isdir(this_out_path): - os.makedirs(this_out_path) - - test_loader = torch_geometric.loader.DataLoader(data, batch_size_mlpf) - - npred, ngen, ncand = {}, {}, {} - for class_ in CLASS_TO_ID.keys(): - npred[class_], ngen[class_], ncand[class_] = [], [], [] - - mlpf.eval() - encoder.eval() - - conf_matrix = np.zeros((6, 6)) - with torch.no_grad(): - for i, batch in tqdm.tqdm(enumerate(test_loader), total=len(test_loader)): - print(f"making predictions: {i+1}/{len(test_loader)}") - - if mode == "ssl": - # make transformation - tracks, clusters = distinguish_PFelements(batch.to(device)) - - # ENCODE - embedding_tracks, embedding_clusters = encoder(tracks, clusters) - - # concat the inputs with embeddings - tracks.x = torch.cat([batch.x[batch.x[:, 0] == 1], embedding_tracks], axis=1) - clusters.x = torch.cat([batch.x[batch.x[:, 0] == 2], embedding_clusters], axis=1) - - event = combine_PFelements(tracks, clusters) - - elif mode == "native": - event = batch - - # make mlpf forward pass - event_dev = event.to(device) - pred_ids_one_hot, pred_momentum, pred_charge = mlpf(event_dev.x, event_dev.batch) - pred_charge = torch.argmax(pred_charge, axis=-1).unsqueeze(axis=-1) - 1 - - pred_charge = torch.argmax(pred_charge, axis=1, keepdim=True) - 1 - - pred_ids = torch.argmax(pred_ids_one_hot, axis=1) - target_ids = event.ygen_id - cand_ids = event.ycand_id - - batch_ids = event.batch.cpu().numpy() - awkvals = { - "gen": particle_array_to_awkward(batch_ids, target_ids.cpu().numpy(), event.ygen.cpu().numpy()), - "cand": particle_array_to_awkward(batch_ids, cand_ids.cpu().numpy(), event.ycand.cpu().numpy()), - "pred": particle_array_to_awkward( - batch_ids, pred_ids.cpu().numpy(), torch.cat([pred_charge, pred_momentum], axis=-1).cpu().numpy() - ), - } - - gen_p4, gen_cls = [], [] - cand_p4, cand_cls = [], [] - pred_p4, pred_cls = [], [] - Xs = [] - for ibatch in np.unique(event.batch.cpu().numpy()): - msk_batch = event.batch == ibatch - msk_gen = target_ids[msk_batch] != 0 - msk_cand = 
cand_ids[msk_batch] != 0 - msk_pred = pred_ids[msk_batch] != 0 - - Xs.append(event.x[msk_batch].cpu().numpy()) - - gen_p4.append(event.ygen[msk_batch, 1:][msk_gen]) - gen_cls.append(target_ids[msk_batch][msk_gen]) - - cand_p4.append(event.ycand[msk_batch, 1:][msk_cand]) - cand_cls.append(cand_ids[msk_batch][msk_cand]) - - pred_p4.append(pred_momentum[msk_batch, :][msk_pred]) - pred_cls.append(pred_ids[msk_batch][msk_pred]) - - Xs = awkward.from_iter(Xs) - - gen_p4 = awkward.from_iter(gen_p4) - gen_cls = awkward.from_iter(gen_cls) - gen_p4 = vector.awk( - awkward.zip( - {"pt": gen_p4[:, :, 0], "eta": gen_p4[:, :, 1], "phi": gen_p4[:, :, 2], "e": gen_p4[:, :, 3]} - ) - ) - - cand_p4 = awkward.from_iter(cand_p4) - cand_cls = awkward.from_iter(cand_cls) - cand_p4 = vector.awk( - awkward.zip( - {"pt": cand_p4[:, :, 0], "eta": cand_p4[:, :, 1], "phi": cand_p4[:, :, 2], "e": cand_p4[:, :, 3]} - ) - ) - - # in case of no predicted particles in the batch - if torch.sum(pred_ids != 0) == 0: - pt = build_dummy_array(len(pred_p4), np.float64) - eta = build_dummy_array(len(pred_p4), np.float64) - phi = build_dummy_array(len(pred_p4), np.float64) - pred_cls = build_dummy_array(len(pred_p4), np.float64) - energy = build_dummy_array(len(pred_p4), np.float64) - pred_p4 = vector.awk(awkward.zip({"pt": pt, "eta": eta, "phi": phi, "e": energy})) - else: - pred_p4 = awkward.from_iter(pred_p4) - pred_cls = awkward.from_iter(pred_cls) - pred_p4 = vector.awk( - awkward.zip( - { - "pt": pred_p4[:, :, 0], - "eta": pred_p4[:, :, 1], - "phi": pred_p4[:, :, 2], - "e": pred_p4[:, :, 3], - } - ) - ) - - jets_coll = {} - - cluster1 = fastjet.ClusterSequence(awkward.Array(gen_p4.to_xyzt()), jetdef) - jets_coll["gen"] = cluster1.inclusive_jets(min_pt=jet_pt) - cluster2 = fastjet.ClusterSequence(awkward.Array(cand_p4.to_xyzt()), jetdef) - jets_coll["cand"] = cluster2.inclusive_jets(min_pt=jet_pt) - cluster3 = fastjet.ClusterSequence(awkward.Array(pred_p4.to_xyzt()), jetdef) - jets_coll["pred"] = cluster3.inclusive_jets(min_pt=jet_pt) - - gen_to_pred = match_two_jet_collections(jets_coll, "gen", "pred", jet_match_dr) - gen_to_cand = match_two_jet_collections(jets_coll, "gen", "cand", jet_match_dr) - matched_jets = awkward.Array({"gen_to_pred": gen_to_pred, "gen_to_cand": gen_to_cand}) - - conf_matrix += sklearn.metrics.confusion_matrix( - target_ids.detach().cpu(), - pred_ids.detach().cpu(), - labels=range(len(CLASS_NAMES_CLIC_LATEX)), - ) - - awkward.to_parquet( - awkward.Array( - { - "inputs": Xs, - "particles": awkvals, - "jets": jets_coll, - "matched_jets": matched_jets, - } - ), - f"{this_out_path}/pred_{i}.parquet", - ) - - for batch_index in range(batch_size_mlpf): - # unpack the batch - pred = pred_ids[event.batch == batch_index] - target = target_ids[event.batch == batch_index] - cand = cand_ids[event.batch == batch_index] - - for class_, id_ in CLASS_TO_ID.items(): - npred[class_].append((pred == id_).sum().item()) - ngen[class_].append((target == id_).sum().item()) - ncand[class_].append((cand == id_).sum().item()) - - make_conf_matrix(conf_matrix, outpath, mode, sample) - npred_[sample], ngen_[sample], ncand_[sample] = make_multiplicity_plots( - npred, ngen, ncand, outpath, mode, sample - ) - yvals, _, _ = load_eval_data(f"{this_out_path}/pred_*.parquet") - plot_jet_ratio(yvals, cp_dir=Path(this_out_path), title=sample) - # if i == 2: - # break - return npred_, ngen_, ncand_ - - -def make_conf_matrix(cm, outpath, mode, save_as): - import itertools - - cmap = plt.get_cmap("Blues") - cm = cm.astype("float") / 
cm.sum(axis=1)[:, np.newaxis] - cm[np.isnan(cm)] = 0.0 - - plt.figure(figsize=(8, 6)) - plt.axes() - plt.imshow(cm, interpolation="nearest", cmap=cmap) - plt.colorbar() - - thresh = cm.max() / 1.5 - for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])): - plt.text( - j, - i, - "{:0.2f}".format(cm[i, j]), - horizontalalignment="center", - color="white" if cm[i, j] > thresh else "black", - fontsize=15, - ) - if mode == "ssl": - plt.title(f"{mode} based MLPF", fontsize=25) - else: - plt.title(f"{mode} MLPF", fontsize=25) - plt.xlabel("Predicted label", fontsize=15) - plt.ylabel("True label", fontsize=15) - - plt.xticks( - range(len(CLASS_NAMES_CLIC_LATEX)), - CLASS_NAMES_CLIC_LATEX, - rotation=45, - fontsize=15, - ) - plt.yticks(range(len(CLASS_NAMES_CLIC_LATEX)), CLASS_NAMES_CLIC_LATEX, fontsize=15) - - plt.tight_layout() - - plt.savefig(f"{outpath}/conf_matrix_{mode}_{save_as}.pdf") - with open(f"{outpath}/conf_matrix_{mode}_{save_as}.pkl", "wb") as f: - pkl.dump(cm, f) - plt.close() - - -def make_multiplicity_plots(npred, ngen, ncand, outpath, mode, save_as): - for class_ in ["charged_hadron", "neutral_hadron", "photon"]: - # Plot the particle multiplicities - plt.figure() - plt.axes() - plt.scatter(ngen[class_], ncand[class_], marker=".", alpha=0.4, label="PF") - plt.scatter(ngen[class_], npred[class_], marker=".", alpha=0.4, label="MLPF") - a = 0.5 * min(np.min(npred[class_]), np.min(ngen[class_])) - b = 1.5 * max(np.max(npred[class_]), np.max(ngen[class_])) - # plt.xlim(a, b) - # plt.ylim(a, b) - plt.plot([a, b], [a, b], color="black", ls="--") - plt.title(class_) - plt.xlabel("number of truth particles") - plt.ylabel("number of reconstructed particles") - plt.legend(loc=4) - plt.savefig(f"{outpath}/multiplicity_plots_{CLASS_TO_ID[class_]}_{mode}_{save_as}.pdf") - plt.close() - - return npred, ngen, ncand - - -def make_multiplicity_plots_both(ret_ssl, ret_native, outpath): - - npred_ssl, ngen_ssl, _ = ret_ssl - npred_native, ngen_native, _ = ret_native - - for data_ in npred_ssl.keys(): - for class_ in ["charged_hadron", "neutral_hadron", "photon"]: - # Plot the particle multiplicities - plt.figure() - plt.axes() - plt.scatter(ngen_ssl[data_][class_], npred_ssl[data_][class_], marker=".", alpha=0.4, label="ssl-based MLPF") - plt.scatter(ngen_native[data_][class_], npred_native[data_][class_], marker=".", alpha=0.4, label="native MLPF") - a = 0.5 * min(np.min(npred_ssl[data_][class_]), np.min(ngen_ssl[data_][class_])) - b = 1.5 * max(np.max(npred_ssl[data_][class_]), np.max(ngen_ssl[data_][class_])) - # plt.xlim(a, b) - # plt.ylim(a, b) - plt.plot([a, b], [a, b], color="black", ls="--") - plt.title(class_) - plt.xlabel("number of truth particles") - plt.ylabel("number of reconstructed particles") - plt.legend(title=data_, loc=4) - plt.savefig(f"{outpath}/multiplicity_plots_{CLASS_TO_ID[class_]}_{data_}.pdf") - plt.close() diff --git a/mlpf/pyg/ssl/training_VICReg.py b/mlpf/pyg/ssl/training_VICReg.py deleted file mode 100644 index db4dc9de9..000000000 --- a/mlpf/pyg/ssl/training_VICReg.py +++ /dev/null @@ -1,237 +0,0 @@ -import json -import pickle as pkl -import time - -import matplotlib -import matplotlib.pyplot as plt -import numpy as np -import torch -import torch.nn.functional as F - -matplotlib.use("Agg") - -# Ignore divide by 0 errors -np.seterr(divide="ignore", invalid="ignore") - - -def off_diagonal(x): - """Copied from VICReg paper github https://github.com/facebookresearch/vicreg/""" - n, m = x.shape - assert n == m - return x.flatten()[:-1].view(n - 1, n + 1)[:, 
1:].flatten() - - -# VICReg loss function -def criterion(tracks, clusters, loss_hparams): - """Based on the pytorch pseudocode presented at the paper in Appendix A.""" - loss_ = {} - - N = tracks.size(0) # batch size - D = tracks.size(1) # dim of representations - - # invariance loss - loss_["Invariance"] = F.mse_loss(tracks, clusters) - loss_["Invariance"] *= loss_hparams["lmbd"] - - # variance loss - std_tracks = torch.sqrt(tracks.var(dim=0) + 1e-04) - std_clusters = torch.sqrt(clusters.var(dim=0) + 1e-04) - loss_["Variance"] = torch.mean(F.relu(1 - std_tracks)) + torch.mean(F.relu(1 - std_clusters)) - loss_["Variance"] *= loss_hparams["mu"] - - # covariance loss - tracks = tracks - tracks.mean(dim=0) - clusters = clusters - clusters.mean(dim=0) - cov_tracks = (tracks.T @ tracks) / (N - 1) - cov_clusters = (clusters.T @ clusters) / (N - 1) - - # loss_["Covariance"] = ( sum_off_diagonal(cov_tracks.pow_(2)) + sum_off_diagonal(cov_clusters.pow_(2)) ) / D - loss_["Covariance"] = off_diagonal(cov_tracks).pow_(2).sum().div(D) + off_diagonal(cov_clusters).pow_(2).sum().div(D) - loss_["Covariance"] *= loss_hparams["nu"] - - return loss_ - - -@torch.no_grad() -def validation_run(multi_gpu, device, vicreg, loaders, loss_hparams): - with torch.no_grad(): - optimizer = None - ret = train(multi_gpu, device, vicreg, loaders, optimizer, loss_hparams) - return ret - - -def train(multi_gpu, device, vicreg, loaders, optimizer, loss_hparams): - """ - A training/validation run over a given epoch that gets called in the training_loop() function. - When optimizer is set to None, it freezes the model for a validation_run. - """ - - is_train = not (optimizer is None) - - if is_train: - print("---->Initiating a training run") - vicreg.train() - loader = loaders["train"] - else: - print("---->Initiating a validation run") - vicreg.eval() - loader = loaders["valid"] - - # initialize loss counters - losses_of_interest = ["Total", "Invariance", "Variance", "Covariance"] - losses = {} - for loss in losses_of_interest: - losses[loss] = 0.0 - - for i, batch in enumerate(loader): - - if multi_gpu: - X = batch - else: - X = batch.to(device) - - # run VICReg forward pass to get the embeddings - embedding_tracks, embedding_clusters = vicreg(X) - - # compute loss - loss_ = criterion(embedding_tracks, embedding_clusters, loss_hparams) - loss_["Total"] = loss_["Invariance"] + loss_["Variance"] + loss_["Covariance"] - - # update parameters - if is_train: - for param in vicreg.parameters(): - param.grad = None - loss_["Total"].backward() - optimizer.step() - - # accumulate the loss to make plots - for loss in losses_of_interest: - losses[loss] += loss_[loss].detach() - - for loss in losses_of_interest: - losses[loss] = losses[loss].cpu().item() / (len(loader)) - - return losses - - -def training_loop_VICReg(multi_gpu, device, vicreg, loaders, n_epochs, patience, optimizer, loss_hparams, outpath): - """ - Main function to perform training. Will call the train() and validation_run() functions every epoch. - - Args: - multi_gpu: flag to indicate if there's more than 1 gpu available to use. - device: "cpu" or "cuda". - vicreg: the VICReg model composed of an Encoder/Decoder. - loaders: a dict() object with keys "train" and "valid", each refering to a pytorch Dataloader. - patience: number of stale epochs allowed before stopping the training. - optimizer: optimizer to use for training (by default SGD which proved more stable). - loss_hparams: a dict() object with keys "lmbd", "u", "v" containing loss hyperparameters. 
- outpath: path to store the model weights and training plots - """ - - t0_initial = time.time() - - losses_of_interest = ["Total", "Invariance", "Variance", "Covariance"] - - best_val_loss, best_train_loss = {}, {} - losses = {} - losses["train"], losses["valid"] = {}, {} - for loss in losses_of_interest: - best_val_loss[loss] = 9999999.9 - losses["train"][loss] = [] - losses["valid"][loss] = [] - - stale_epochs = 0 - - for epoch in range(n_epochs): - t0 = time.time() - - if stale_epochs > patience: - print("breaking due to stale epochs") - break - - # training step - losses_t = train(multi_gpu, device, vicreg, loaders, optimizer, loss_hparams) - - for loss in losses_of_interest: - losses["train"][loss].append(losses_t[loss]) - - # validation step - losses_v = validation_run(multi_gpu, device, vicreg, loaders, loss_hparams) - - for loss in losses_of_interest: - losses["valid"][loss].append(losses_v[loss]) - - # save the lowest value of each component of the loss to print it on the legend of the loss plots - for loss in losses_of_interest: - if loss == "Total": - if losses_v[loss] < best_val_loss[loss]: - best_val_loss[loss] = losses_v[loss] - best_train_loss[loss] = losses_t[loss] - - # save the model differently if the model was wrapped with DataParallel - if multi_gpu: - state_dict = vicreg.module.state_dict() - else: - state_dict = vicreg.state_dict() - - torch.save(state_dict, f"{outpath}/VICReg_best_epoch_weights.pth") - - # dump best epoch - with open(f"{outpath}/VICReg_best_epoch.json", "w") as fp: - json.dump({"best_epoch": epoch}, fp) - - # for early-stopping purposes - stale_epochs = 0 - else: - stale_epochs += 1 - else: - if losses_v[loss] < best_val_loss[loss]: - best_val_loss[loss] = losses_v[loss] - best_train_loss[loss] = losses_t[loss] - - t1 = time.time() - - epochs_remaining = n_epochs - (epoch + 1) - time_per_epoch = (t1 - t0_initial) / (epoch + 1) - eta = epochs_remaining * time_per_epoch / 60 - - print( - f"epoch={epoch + 1} / {n_epochs} " - + f"train_loss={round(losses_t['Total'], 4)} " - + f"valid_loss={round(losses_v['Total'], 4)} " - + f"stale={stale_epochs} " - + f"time={round((t1-t0)/60, 2)}m " - + f"eta={round(eta, 1)}m" - ) - - for loss in losses_of_interest: - - # make total loss plot - fig, ax = plt.subplots() - ax.plot( - range(len(losses["train"][loss])), losses["train"][loss], label=f"training ({best_train_loss[loss]:.4f})" - ) - ax.plot( - range(len(losses["valid"][loss])), losses["valid"][loss], label=f"validation ({best_val_loss[loss]:.4f})" - ) - ax.set_xlabel("Epochs") - ax.set_ylabel(f"{loss} Loss") - ax.legend( - title=rf"VICReg - ($\lambda={loss_hparams['lmbd']} - \mu={loss_hparams['mu']} - \nu={loss_hparams['nu']}$)", - loc="best", - title_fontsize=20, - fontsize=15, - ) - plt.tight_layout() - plt.savefig(f"{outpath}/VICReg_loss_{loss}.pdf") - plt.close() - - with open(f"{outpath}/VICReg_losses.pkl", "wb") as f: - pkl.dump(losses, f) - - plt.tight_layout() - plt.close() - - print("----------------------------------------------------------") - print(f"Done with training. 
Total training time is {round((time.time() - t0_initial)/60,3)}min") diff --git a/mlpf/pyg/ssl/utils.py b/mlpf/pyg/ssl/utils.py deleted file mode 100644 index 0631c87f2..000000000 --- a/mlpf/pyg/ssl/utils.py +++ /dev/null @@ -1,327 +0,0 @@ -import glob -import json -import os -import os.path as osp -import pickle as pkl -import random -import shutil -import sys - -import matplotlib -import torch -from torch_geometric.data import Batch - -matplotlib.use("Agg") - -# define input dimensions -X_FEATURES_TRK = [ - "type", - "pt", - "eta", - "sin_phi", - "cos_phi", - "p", - "chi2", - "ndf", - "dEdx", - "dEdxError", - "radiusOfInnermostHit", - "tanLambda", - "D0", - "omega", - "Z0", - "time", -] -X_FEATURES_CL = [ - "type", - "et", - "eta", - "sin_phi", - "cos_phi", - "energy", - "position.x", - "position.y", - "position.z", - "iTheta", - "energy_ecal", - "energy_hcal", - "energy_other", - "num_hits", - "sigma_x", - "sigma_y", - "sigma_z", -] - -CLUSTERS_X = len(X_FEATURES_CL) - 1 # remove the `type` feature -TRACKS_X = len(X_FEATURES_TRK) - 1 # remove the `type` feature - - -def distinguish_PFelements(batch): - """Takes an event~Batch() and splits it into two Batch() objects representing the tracks/clusters.""" - - track_id = 1 - cluster_id = 2 - - tracks = Batch( - x=batch.x[batch.x[:, 0] == track_id][:, 1:].float()[ - :, :TRACKS_X - ], # remove the first input feature which is not needed anymore - ygen=batch.ygen[batch.x[:, 0] == track_id], - ygen_id=batch.ygen_id[batch.x[:, 0] == track_id], - ycand=batch.ycand[batch.x[:, 0] == track_id], - ycand_id=batch.ycand_id[batch.x[:, 0] == track_id], - batch=batch.batch[batch.x[:, 0] == track_id], - ) - clusters = Batch( - x=batch.x[batch.x[:, 0] == cluster_id][:, 1:].float()[ - :, :CLUSTERS_X - ], # remove the first input feature which is not needed anymore - ygen=batch.ygen[batch.x[:, 0] == cluster_id], - ygen_id=batch.ygen_id[batch.x[:, 0] == cluster_id], - ycand=batch.ycand[batch.x[:, 0] == cluster_id], - ycand_id=batch.ycand_id[batch.x[:, 0] == cluster_id], - batch=batch.batch[batch.x[:, 0] == cluster_id], - ) - return tracks, clusters - - -def combine_PFelements(tracks, clusters): - """Takes two Batch() objects represeting the learned latent representation of - tracks and the clusters and combines them under a single event~Batch().""" - - event = Batch( - x=torch.cat([tracks.x, clusters.x]), - ygen=torch.cat([tracks.ygen, clusters.ygen]), - ygen_id=torch.cat([tracks.ygen_id, clusters.ygen_id]), - ycand=torch.cat([tracks.ycand, clusters.ycand]), - ycand_id=torch.cat([tracks.ycand_id, clusters.ycand_id]), - batch=torch.cat([tracks.batch, clusters.batch]), - ) - - return event - - -def load_VICReg(device, outpath): - - print("Loading a previously trained model..") - vicreg_state_dict = torch.load(f"{outpath}/VICReg_best_epoch_weights.pth", map_location=device) - - with open(f"{outpath}/encoder_model_kwargs.pkl", "rb") as f: - encoder_model_kwargs = pkl.load(f) - with open(f"{outpath}/decoder_model_kwargs.pkl", "rb") as f: - decoder_model_kwargs = pkl.load(f) - - return vicreg_state_dict, encoder_model_kwargs, decoder_model_kwargs - - -def save_VICReg(args, outpath, encoder, encoder_model_kwargs, decoder, decoder_model_kwargs): - - num_encoder_parameters = sum(p.numel() for p in encoder.parameters() if p.requires_grad) - num_decoder_parameters = sum(p.numel() for p in decoder.parameters() if p.requires_grad) - - print(f"Num of 'encoder' parameters: {num_encoder_parameters}") - print(f"Num of 'decoder' parameters: {num_decoder_parameters}") - - if not 
osp.isdir(outpath): - os.makedirs(outpath) - - else: # if directory already exists - if not args.overwrite: # if not overwrite then exit - print("model already exists, please delete it") - sys.exit(0) - - print("model already exists, deleting it") - - filelist = [f for f in os.listdir(outpath) if not f.endswith(".txt")] # don't remove the newly created logs.txt - for f in filelist: - try: - shutil.rmtree(os.path.join(outpath, f)) - except Exception: - os.remove(os.path.join(outpath, f)) - - with open(f"{outpath}/encoder_model_kwargs.pkl", "wb") as f: # dump model architecture - pkl.dump(encoder_model_kwargs, f, protocol=pkl.HIGHEST_PROTOCOL) - with open(f"{outpath}/decoder_model_kwargs.pkl", "wb") as f: # dump model architecture - pkl.dump(decoder_model_kwargs, f, protocol=pkl.HIGHEST_PROTOCOL) - - with open(f"{outpath}/hyperparameters.json", "w") as fp: # dump hyperparameters - json.dump( - { - "data_split_mode": args.data_split_mode, - "n_epochs": args.n_epochs_VICReg, - "lr": args.lr, - "bs_VICReg": args.bs_VICReg, - "width_encoder": args.width_encoder, - "embedding_dim": args.embedding_dim_VICReg, - "num_convs": args.num_convs, - "space_dim": args.space_dim, - "propagate_dim": args.propagate_dim, - "k": args.nearest, - "width_decoder": args.width_decoder, - "output_dim": args.expand_dim, - "lmbd": args.lmbd, - "mu": args.mu, - "nu": args.nu, - "num_encoder_parameters": num_encoder_parameters, - "num_decoder_parameters": num_decoder_parameters, - }, - fp, - ) - - -def save_MLPF(args, outpath, mlpf, mlpf_model_kwargs, mode): - """ - Saves the mlpf model in the `outpath` provided. - Dumps the hyperparameters of the mlpf model in a json file. - - Args - mode: choices are "ssl" or "native" - """ - - num_mlpf_parameters = sum(p.numel() for p in mlpf.parameters() if p.requires_grad) - print(f"Num of '{mode}-mlpf' parameters: {num_mlpf_parameters}") - - if not osp.isdir(outpath): - os.makedirs(outpath) - - else: # if directory already exists - filelist = [f for f in os.listdir(outpath) if not f.endswith(".txt")] # don't remove the newly created logs.txt - - for f in filelist: - try: - shutil.rmtree(os.path.join(outpath, f)) - except Exception: - os.remove(os.path.join(outpath, f)) - - with open(f"{outpath}/mlpf_model_kwargs.pkl", "wb") as f: # dump model architecture - pkl.dump(mlpf_model_kwargs, f, protocol=pkl.HIGHEST_PROTOCOL) - - with open(f"{outpath}/hyperparameters.json", "w") as fp: # dump hyperparameters - json.dump( - { - "data_split_mode": args.data_split_mode, - "n_epochs": args.n_epochs_mlpf, - "lr": args.lr, - "bs_mlpf": args.bs_mlpf, - "width": args.width_mlpf, - "embedding_dim": args.embedding_dim_mlpf, - "num_convs": args.num_convs, - "space_dim": args.space_dim, - "propagate_dim": args.propagate_dim, - "k": args.nearest, - "mode": mode, - "num_mlpf_parameters": num_mlpf_parameters, - }, - fp, - ) - - -def data_split(data_path, data_split_mode): - """ - Depending on the data split mode chosen, the function returns different data splits. - - Choices for data_split_mode - 1. `quick`: uses only 1 datafile of each sample for quick debugging. Nothing interesting there. - 2. `domain_adaptation`: uses qq samples to train/validate VICReg and TTbar samples to train/validate MLPF. - 3. `mix`: uses a mix of both qq and TTbar samples to train/validate VICReg and MLPF. 
- - Returns (each as a list) - data_VICReg_train, data_VICReg_valid, data_mlpf_train, data_mlpf_valid, data_test_qq, data_test_ttbar - - """ - print(f"Will use data split mode `{data_split_mode}`") - - if data_split_mode == "quick": - data_qq = torch.load(f"{data_path}/p8_ee_qq_ecm380/processed/data_0.pt") - data_ttbar = torch.load(f"{data_path}/p8_ee_tt_ecm380/processed/data_0.pt") - - data_test_qq = data_qq[: round(0.1 * len(data_qq))] - data_test_ttbar = data_ttbar[: round(0.1 * len(data_ttbar))] - - # label remaining data as `rem` - rem_qcd = data_qq[round(0.1 * len(data_qq)) :] - rem_ttbar = data_ttbar[round(0.1 * len(data_qq)) :] - - data_VICReg = rem_qcd[: round(0.8 * len(rem_qcd))] + rem_ttbar[: round(0.8 * len(rem_ttbar))] - data_mlpf = rem_qcd[round(0.8 * len(rem_qcd)) :] + rem_ttbar[round(0.8 * len(rem_ttbar)) :] - - # shuffle the samples after mixing (not super necessary since the DataLoaders will shuffle anyway) - random.shuffle(data_VICReg) - random.shuffle(data_mlpf) - - data_VICReg_train = data_VICReg[: round(0.9 * len(data_VICReg))] - data_VICReg_valid = data_VICReg[round(0.9 * len(data_VICReg)) :] - - data_mlpf_train = data_mlpf[: round(0.9 * len(data_mlpf))] - data_mlpf_valid = data_mlpf[round(0.9 * len(data_mlpf)) :] - - else: # actual meaningful data splits - # load the qq and ttbar samples seperately - qq_files = glob.glob(f"{data_path}/p8_ee_qq_ecm380/processed/*") - ttbar_files = glob.glob(f"{data_path}/p8_ee_tt_ecm380/processed/*") - - data_qq = [] - for file in list(qq_files): - data_qq += torch.load(f"{file}") - - data_ttbar = [] - for file in list(ttbar_files): - data_ttbar += torch.load(f"{file}") - - # use 10% of each sample for testing - frac_qq_test = round(0.1 * len(data_qq)) - frac_tt_test = round(0.1 * len(data_ttbar)) - data_test_qq = data_qq[:frac_qq_test] - data_test_ttbar = data_ttbar[:frac_tt_test] - - # label remaining data as `rem` - rem_qq = data_qq[frac_qq_test:] - rem_ttbar = data_ttbar[frac_tt_test:] - - frac_qq_train = round(0.8 * len(rem_qq)) - frac_tt_train = round(0.8 * len(rem_ttbar)) - - assert frac_qq_train > 0 - assert frac_qq_test > 0 - assert frac_tt_train > 0 - assert frac_tt_test > 0 - - if data_split_mode == "domain_adaptation": - """ - use remaining qq samples for VICReg with an 80-20 train-val split. - use remaining TTbar samples for MLPF with an 80-20 train-val split. - """ - data_VICReg_train = rem_qq[:frac_qq_train] - data_VICReg_valid = rem_qq[frac_qq_train:] - - data_mlpf_train = rem_ttbar[:frac_tt_train] - data_mlpf_valid = rem_ttbar[frac_tt_train:] - - elif data_split_mode == "mix": - """ - use (80% of qq + 80% of remaining TTbar) samples for VICReg with a 90-10 train-val split. - use (20% of qq + 20% of remaining TTbar) samples for MLPF with a 90-10 train-val split. 
- """ - data_VICReg = rem_qq[:frac_qq_train] + rem_ttbar[:frac_tt_train] - data_mlpf = rem_qq[frac_qq_train:] + rem_ttbar[frac_tt_train:] - - # shuffle the samples after mixing (not super necessary since the DataLoaders will shuffle anyway) - random.shuffle(data_VICReg) - random.shuffle(data_mlpf) - - frac_VICReg_train = round(0.9 * len(data_VICReg)) - data_VICReg_train = data_VICReg[:frac_VICReg_train] - data_VICReg_valid = data_VICReg[frac_VICReg_train:] - - frac_mlpf_train = round(0.9 * len(data_mlpf)) - data_mlpf_train = data_mlpf[:frac_mlpf_train] - data_mlpf_valid = data_mlpf[frac_mlpf_train:] - - print(f"Will use {len(data_VICReg_train)} events to train VICReg") - print(f"Will use {len(data_VICReg_valid)} events to validate VICReg") - print(f"Will use {len(data_mlpf_train)} events to train MLPF") - print(f"Will use {len(data_mlpf_valid)} events to validate MLPF") - print(f"Will use {len(data_test_ttbar)} events to test MLPF on TTbar") - print(f"Will use {len(data_test_qq)} events to test MLPF on QCD") - - return data_VICReg_train, data_VICReg_valid, data_mlpf_train, data_mlpf_valid, data_test_qq, data_test_ttbar diff --git a/mlpf/pyg/training.py b/mlpf/pyg/training.py index c4f159d2e..e0d852c03 100644 --- a/mlpf/pyg/training.py +++ b/mlpf/pyg/training.py @@ -1,4 +1,3 @@ -import json import pickle as pkl import time from typing import Optional @@ -13,6 +12,7 @@ from torch.utils.tensorboard import SummaryWriter from .logger import _logger +from .utils import unpack_predictions, unpack_target # from torch.profiler import profile, record_function, ProfilerActivity @@ -20,9 +20,28 @@ # Ignore divide by 0 errors np.seterr(divide="ignore", invalid="ignore") -# keep track of step across epochs -ISTEP_GLOBAL_TRAIN = 0 -ISTEP_GLOBAL_VALID = 0 + +def mlpf_loss(y, ypred): + """ + Args + y [dict]: relevant keys are "cls_id, momentum, charge" + ypred [dict]: relevant keys are "cls_id_onehot, momentum, charge" + """ + loss = {} + loss_obj_id = FocalLoss(gamma=2.0) + loss["Classification"] = 100 * loss_obj_id(ypred["cls_id_onehot"], y["cls_id"]) + + msk_true_particle = torch.unsqueeze((y["cls_id"] != 0).to(dtype=torch.float32), axis=-1) + + loss["Regression"] = 10 * torch.nn.functional.huber_loss( + ypred["momentum"] * msk_true_particle, y["momentum"] * msk_true_particle + ) + loss["Charge"] = torch.nn.functional.cross_entropy( + ypred["charge"] * msk_true_particle, (y["charge"] * msk_true_particle[:, 0]).to(dtype=torch.int64) + ) + + loss["Total"] = loss["Classification"] + loss["Regression"] + loss["Charge"] + return loss # from https://github.com/AdeelH/pytorch-multi-class-focal-loss/blob/master/focal_loss.py @@ -106,245 +125,264 @@ def forward(self, x: Tensor, y: Tensor) -> Tensor: return loss -@torch.no_grad() -def validation_run(rank, world_size, model, train_loader, valid_loader, tensorboard_writer=None): - with torch.no_grad(): - optimizer = None - ret = train(rank, world_size, model, train_loader, valid_loader, optimizer, tensorboard_writer) - return ret - - -def train(rank, world_size, model, train_loader, valid_loader, optimizer, tensorboard_writer=None): +def train( + rank, + world_size, + model, + optimizer, + train_loader, + valid_loader, + best_val_loss, + stale_epochs, + patience, + outpath, + tensorboard_writer=None, +): """ - A training/validation run over a given epoch that gets called in the train_mlpf() function. - When optimizer is set to None, it freezes the model for a validation_run. + Performs training over a given epoch. 
Will run a validation step every N_STEPS and after the last training batch. """ - global ISTEP_GLOBAL_TRAIN, ISTEP_GLOBAL_VALID - is_train = not (optimizer is None) - - loss_obj_id = FocalLoss(gamma=2.0) - if is_train: - _logger.info(f"Initiating a training run on device {rank}", color="red") - step_type = "train" - loader = train_loader - model.train() - else: - _logger.info(f"Initiating a validation run on device {rank}", color="red") - step_type = "valid" - loader = valid_loader - model.eval() + N_STEPS = 1000 + _logger.info(f"Initiating a training run on device {rank}", color="red") - # initialize loss counters - losses = {"Total": 0.0, "Classification": 0.0, "Regression": 0.0, "Charge": 0.0} + # initialize loss counters (note: these will be reset after N_STEPS) + train_loss = {"Total": 0.0, "Classification": 0.0, "Regression": 0.0, "Charge": 0.0} + valid_loss = {"Total": 0.0, "Classification": 0.0, "Regression": 0.0, "Charge": 0.0} - num_iterations = 0 - for i, batch in tqdm.tqdm(enumerate(loader), total=len(loader)): - num_iterations += 1 + # this one will keep accumulating `train_loss` and then return the average + epoch_loss = {"Total": 0.0, "Classification": 0.0, "Regression": 0.0, "Charge": 0.0} + istep = 0 + model.train() + for itrain, batch in tqdm.tqdm(enumerate(train_loader), total=len(train_loader)): + istep += 1 if tensorboard_writer: tensorboard_writer.add_scalar( - "step_{}/num_elems".format(step_type), + "step_train/num_elems", batch.X.shape[0], - ISTEP_GLOBAL_TRAIN if is_train else ISTEP_GLOBAL_VALID, ) - event = batch.to(rank) + ygen = unpack_target(batch.to(rank).ygen) + ypred = unpack_predictions(model(batch.to(rank))) - # recall target ~ ["PDG", "charge", "pt", "eta", "sin_phi", "cos_phi", "energy", "jet_idx"] - target_ids = event.ygen[:, 0].long() - target_charge = torch.clamp((event.ygen[:, 1] + 1).to(dtype=torch.float32), 0, 2) # -1, 0, 1 -> 0, 1, 2 - target_momentum = event.ygen[:, 2:-1].to(dtype=torch.float32) - - # make mlpf forward pass - # t0 = time.time() - if is_train: - pred_ids_one_hot, pred_momentum, pred_charge = model(event) - else: - if world_size > 1: # validation run is only run on a single machine - pred_ids_one_hot, pred_momentum, pred_charge = model.module(event) - else: - pred_ids_one_hot, pred_momentum, pred_charge = model(event) - # print(f"{event}: {(time.time() - t0):.2f}s") - - for icls in range(pred_ids_one_hot.shape[1]): + for icls in range(ypred["cls_id_onehot"].shape[1]): if tensorboard_writer: tensorboard_writer.add_scalar( - "step_{}/num_cls_{}".format(step_type, icls), - torch.sum(target_ids == icls), - ISTEP_GLOBAL_TRAIN if is_train else ISTEP_GLOBAL_VALID, + f"step_train/num_cls_{icls}", + torch.sum(ygen["cls_id"] == icls), ) # JP: need to debug this # assert np.all(target_charge.unique().cpu().numpy() == [0, 1, 2]) + loss = mlpf_loss(ygen, ypred) + + for param in model.parameters(): + param.grad = None + loss["Total"].backward() + optimizer.step() + + for loss_ in train_loss: + train_loss[loss_] += loss[loss_].detach() + for loss_ in epoch_loss: + epoch_loss[loss_] += loss[loss_].detach() + + # run a quick validation run at intervals of N_STEPS or at the last step + if (((itrain % N_STEPS) == 0) and (itrain != 0)) or (itrain == (len(train_loader) - 1)): + if itrain == (len(train_loader) - 1): + nsteps = istep + else: + nsteps = N_STEPS + istep = 0 - loss_ = {} - # for CLASSIFYING PID - loss_["Classification"] = 100 * loss_obj_id(pred_ids_one_hot, target_ids) - # REGRESSING p4: mask the loss in cases there is no true particle - 
msk_true_particle = torch.unsqueeze((target_ids != 0).to(dtype=torch.float32), axis=-1) - loss_["Regression"] = 10 * torch.nn.functional.huber_loss( - pred_momentum * msk_true_particle, target_momentum * msk_true_particle - ) - # PREDICTING CHARGE - loss_["Charge"] = torch.nn.functional.cross_entropy( - pred_charge * msk_true_particle, (target_charge * msk_true_particle[:, 0]).to(dtype=torch.int64) - ) - # TOTAL LOSS - loss_["Total"] = loss_["Classification"] + loss_["Regression"] + loss_["Charge"] + if world_size > 1: + dist.barrier() - if tensorboard_writer: - tensorboard_writer.add_scalar( - "step_{}/loss".format(step_type), - loss_["Total"], - ISTEP_GLOBAL_TRAIN if is_train else ISTEP_GLOBAL_VALID, + if tensorboard_writer: + for loss_ in train_loss: + tensorboard_writer.add_scalar( + f"step_train/loss_{loss_}", + train_loss[loss_] / nsteps, + ) + tensorboard_writer.flush() + + _logger.info( + f"Rank {rank}: " + + f"train_loss_tot={train_loss['Total']/nsteps:.2f} " + + f"train_loss_id={train_loss['Classification']/nsteps:.2f} " + + f"train_loss_momentum={train_loss['Regression']/nsteps:.2f} " + + f"train_loss_charge={train_loss['Charge']/nsteps:.2f} " ) - if is_train: - for param in model.parameters(): - param.grad = None - loss_["Total"].backward() - optimizer.step() + train_loss = {"Total": 0.0, "Classification": 0.0, "Regression": 0.0, "Charge": 0.0} + + if (rank == 0) or (rank == "cpu"): + _logger.info(f"Initiating a quick validation run on device {rank}", color="red") + model.eval() + + valid_loss = {"Total": 0.0, "Classification": 0.0, "Regression": 0.0, "Charge": 0.0} + with torch.no_grad(): + for ival, batch in tqdm.tqdm(enumerate(valid_loader), total=len(valid_loader)): + ygen = unpack_target(batch.to(rank).ygen) + if world_size > 1: # validation is only run on a single machine + ypred = unpack_predictions(model.module(batch.to(rank))) + else: + ypred = unpack_predictions(model(batch.to(rank))) - for loss in losses: - losses[loss] += loss_[loss].detach().cpu().item() + loss = mlpf_loss(ygen, ypred) + + for loss_ in valid_loss: + valid_loss[loss_] += loss[loss_].detach() + + for loss_ in valid_loss: + valid_loss[loss_] = valid_loss[loss_].cpu().item() / len(valid_loader) + + if tensorboard_writer: + for loss_ in valid_loss: + tensorboard_writer.add_scalar( + f"step_valid/loss_{loss_}", + valid_loss[loss_], + ) + + if valid_loss["Total"] < best_val_loss: + best_val_loss = valid_loss["Total"] + + if isinstance(model, torch.nn.parallel.DistributedDataParallel): + model_state_dict = model.module.state_dict() + else: + model_state_dict = model.state_dict() + + torch.save( + {"model_state_dict": model_state_dict, "optimizer_state_dict": optimizer.state_dict()}, + f"{outpath}/best_weights.pth", + ) + _logger.info( + f"finished {itrain}/{len(train_loader)} iterations and saved the model at {outpath}/best_weights.pth" # noqa + ) + stale_epochs = 0 + else: + _logger.info(f"finished {itrain}/{len(train_loader)} iterations") + stale_epochs += 1 + + _logger.info( + f"Rank {rank}: " + + f"val_loss_tot={valid_loss['Total']:.2f} " + + f"val_loss_id={valid_loss['Classification']:.2f} " + + f"val_loss_momentum={valid_loss['Regression']:.2f} " + + f"val_loss_charge={valid_loss['Charge']:.2f} " + + f"best_val_loss={best_val_loss:.2f} " + + f"stale={stale_epochs} " + ) + + if stale_epochs > patience: + _logger.info("breaking due to stale epochs") + return None, None, None, stale_epochs if tensorboard_writer: tensorboard_writer.flush() - if is_train: - ISTEP_GLOBAL_TRAIN += 1 - else: - 
ISTEP_GLOBAL_VALID += 1 + if world_size > 1: + dist.barrier() - for loss in losses: - losses[loss] = losses[loss] / num_iterations + model.train() # prepare for next training loop - _logger.info( - f"loss_id={losses['Classification']:.4f} loss_momentum={losses['Regression']:.4f} loss_charge={losses['Charge']:.4f}" - ) + for loss_ in epoch_loss: + epoch_loss[loss_] = epoch_loss[loss_].cpu().item() / len(train_loader) - return losses + return epoch_loss, valid_loss, best_val_loss, stale_epochs -def train_mlpf(rank, world_size, model, train_loader, valid_loader, n_epochs, patience, lr, outpath): +def train_mlpf(rank, world_size, model, optimizer, train_loader, valid_loader, num_epochs, patience, outpath): """ - Will run a full training by calling train() and validation_run() every epoch. + Will run a full training by calling train(). Args: rank: 'cpu' or int representing the gpu device id - model: a pytorch model that may be wrapped by DistributedDataParallel - train_loader: a pytorch Dataloader that loads the training data in the form ~ DataBatch(X, ygen, ycands) - valid_loader: a pytorch Dataloader that loads the validation data in the form ~ DataBatch(X, ygen, ycands) + model: a pytorch model (may be wrapped by DistributedDataParallel) + train_loader: a pytorch geometric Dataloader that loads the training data in the form ~ DataBatch(X, ygen, ycands) + valid_loader: a pytorch geometric Dataloader that loads the validation data in the form ~ DataBatch(X, ygen, ycands) patience: number of stale epochs before stopping the training - lr: learning rate to use for training outpath: path to store the model weights and training plots """ - tensorboard_writer = SummaryWriter(outpath) + tensorboard_writer = SummaryWriter(f"{outpath}/runs/rank_{rank}/") t0_initial = time.time() losses_of_interest = ["Total", "Classification", "Regression"] losses = {} - losses["train"], losses["valid"], best_val_loss, best_train_loss = {}, {}, {}, {} + losses["train"], losses["valid"] = {}, {} for loss in losses_of_interest: losses["train"][loss], losses["valid"][loss] = [], [] - best_val_loss[loss] = 99999.9 stale_epochs = 0 - - optimizer = torch.optim.AdamW(model.parameters(), lr=lr) - - for epoch in range(n_epochs): + best_val_loss = 99999.9 + for epoch in range(num_epochs): + _logger.info(f"Initiating epoch # {epoch}", color="bold") t0 = time.time() + # training step + losses_t, losses_v, best_val_loss, stale_epochs = train( + rank, + world_size, + model, + optimizer, + train_loader, + valid_loader, + best_val_loss, + stale_epochs, + patience, + outpath, + tensorboard_writer, + ) + if stale_epochs > patience: - _logger.info("breaking due to stale epochs") break - # training step - losses_t = train(rank, world_size, model, train_loader, valid_loader, optimizer, tensorboard_writer) # with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True) as prof: # with record_function("model_train"): # print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20)) for k, v in losses_t.items(): tensorboard_writer.add_scalar(f"epoch/train_loss_rank_{rank}_" + k, v, epoch) - for loss in losses_of_interest: - losses["train"][loss].append(losses_t[loss]) - - # validation step on a single machine - if world_size > 1: - dist.barrier() - if (rank == 0) or (rank == "cpu"): - losses_v = validation_run(rank, world_size, model, train_loader, valid_loader, tensorboard_writer) - if world_size > 1: - dist.barrier() if (rank == 0) or (rank == "cpu"): for loss in losses_of_interest: + 
losses["train"][loss].append(losses_t[loss]) losses["valid"][loss].append(losses_v[loss]) for k, v in losses_v.items(): tensorboard_writer.add_scalar("epoch/valid_loss_" + k, v, epoch) tensorboard_writer.flush() - # save the lowest value of each component of the loss to print it on the legend of the loss plots - for loss in losses_of_interest: - if loss == "Total": - if losses_v[loss] < best_val_loss[loss]: - best_val_loss[loss] = losses_v[loss] - best_train_loss[loss] = losses_t[loss] - - # save the model - if isinstance(model, torch.nn.parallel.DistributedDataParallel): - state_dict = model.module.state_dict() - else: - state_dict = model.state_dict() - - torch.save(state_dict, f"{outpath}/best_epoch_weights.pth") - - with open(f"{outpath}/best_epoch.json", "w") as fp: # dump best epoch - json.dump({"best_epoch": epoch}, fp) - - # for early-stopping purposes - stale_epochs = 0 - else: - stale_epochs += 1 - else: - if losses_v[loss] < best_val_loss[loss]: - best_val_loss[loss] = losses_v[loss] - best_train_loss[loss] = losses_t[loss] - t1 = time.time() - epochs_remaining = n_epochs - (epoch + 1) + epochs_remaining = num_epochs - (epoch + 1) time_per_epoch = (t1 - t0_initial) / (epoch + 1) eta = epochs_remaining * time_per_epoch / 60 _logger.info( - f"Rank {rank}: epoch={epoch + 1} / {n_epochs} " - + f"train_loss={round(losses_t['Total'], 4)} " - + f"valid_loss={round(losses_v['Total'], 4)} " + f"Rank {rank}: epoch={epoch + 1} / {num_epochs} " + + f"train_loss={losses_t['Total']:.4f} " + + f"valid_loss={losses_v['Total']:.4f} " + f"stale={stale_epochs} " + f"time={round((t1-t0)/60, 2)}m " + f"eta={round(eta, 1)}m" ) - # make loss plots for loss in losses_of_interest: fig, ax = plt.subplots() + ax.plot( range(len(losses["train"][loss])), losses["train"][loss], - label="training ({:.3f})".format(best_train_loss[loss]), + label="training", ) ax.plot( range(len(losses["valid"][loss])), losses["valid"][loss], - label="validation ({:.3f})".format(best_val_loss[loss]), + label=f"validation ({best_val_loss:.3f})", ) + ax.set_xlabel("Epochs") ax.set_ylabel(f"{loss} Loss") ax.set_ylim(0.8 * losses["train"][loss][-1], 1.2 * losses["train"][loss][-1]) diff --git a/mlpf/pyg/utils.py b/mlpf/pyg/utils.py index 307bade5e..7224f6211 100644 --- a/mlpf/pyg/utils.py +++ b/mlpf/pyg/utils.py @@ -113,14 +113,58 @@ ], } -Y_FEATURES = { - "cms": ["PDG", "charge", "pt", "eta", "sin_phi", "cos_phi", "energy", "jet_idx"], - "delphes": ["PDG", "charge", "pt", "eta", "sin_phi", "cos_phi", "energy", "jet_idx"], - "clic": ["PDG", "charge", "pt", "eta", "sin_phi", "cos_phi", "energy", "jet_idx"], -} +Y_FEATURES = ["cls_id", "charge", "pt", "eta", "sin_phi", "cos_phi", "energy", "jet_idx"] + + +def unpack_target(y): + ret = {} + ret["cls_id"] = y[:, 0].long() + ret["charge"] = torch.clamp((y[:, 1] + 1).to(dtype=torch.float32), 0, 2) # -1, 0, 1 -> 0, 1, 2 + + for i, feat in enumerate(Y_FEATURES): + if i >= 2: # skip the cls and charge as they are defined above + ret[feat] = y[:, i].to(dtype=torch.float32) + ret["phi"] = torch.atan2(ret["sin_phi"], ret["cos_phi"]) + + # do some sanity checks + assert torch.all(ret["pt"] >= 0.0) # pt + assert torch.all(torch.abs(ret["sin_phi"]) <= 1.0) # sin_phi + assert torch.all(torch.abs(ret["cos_phi"]) <= 1.0) # cos_phi + assert torch.all(ret["energy"] >= 0.0) # energy + + # note ~ momentum = ["pt", "eta", "sin_phi", "cos_phi", "energy"] + ret["momentum"] = y[:, 2:-1].to(dtype=torch.float32) + ret["p4"] = torch.cat( + [ret["pt"].unsqueeze(1), ret["eta"].unsqueeze(1), 
ret["phi"].unsqueeze(1), ret["energy"].unsqueeze(1)], axis=1 + ) + + return ret + + +def unpack_predictions(preds): + ret = {} + ret["cls_id_onehot"], ret["momentum"], ret["charge"] = preds + + # ret["charge"] = torch.argmax(ret["charge"], axis=1, keepdim=True) - 1 + # unpacking + ret["pt"] = ret["momentum"][:, 0] + ret["eta"] = ret["momentum"][:, 1] + ret["sin_phi"] = ret["momentum"][:, 2] + ret["cos_phi"] = ret["momentum"][:, 3] + ret["energy"] = ret["momentum"][:, 4] -def save_mlpf(args, mlpf, model_kwargs, outdir): + # new variables + ret["cls_id"] = torch.argmax(ret["cls_id_onehot"], axis=-1) + ret["phi"] = torch.atan2(ret["sin_phi"], ret["cos_phi"]) + ret["p4"] = torch.cat( + [ret["pt"].unsqueeze(1), ret["eta"].unsqueeze(1), ret["phi"].unsqueeze(1), ret["energy"].unsqueeze(1)], axis=1 + ) + + return ret + + +def save_HPs(args, mlpf, model_kwargs, outdir): """Simple function to store the model parameters and training hyperparameters.""" with open(f"{outdir}/model_kwargs.pkl", "wb") as f: # dump model architecture @@ -164,20 +208,18 @@ def get_distributed_sampler(self): sampler = torch.utils.data.distributed.DistributedSampler(self.ds) return sampler - def get_loader(self, batch_size, world_size, num_workers=0, prefetch_factor=4): + def get_loader(self, batch_size, world_size, num_workers=None, prefetch_factor=2): if world_size > 1: + sampler = self.get_distributed_sampler() + else: + sampler = self.get_sampler() + + if num_workers is not None: return DataLoader( self.ds, batch_size=batch_size, collate_fn=Collater(self.keys_to_get), - sampler=self.get_distributed_sampler(), - ) - elif num_workers: - return DataLoader( - self.ds, - batch_size=batch_size, - collate_fn=Collater(self.keys_to_get), - sampler=self.get_sampler(), + sampler=sampler, num_workers=num_workers, prefetch_factor=prefetch_factor, ) @@ -186,7 +228,7 @@ def get_loader(self, batch_size, world_size, num_workers=0, prefetch_factor=4): self.ds, batch_size=batch_size, collate_fn=Collater(self.keys_to_get), - sampler=self.get_sampler(), + sampler=sampler, ) def __len__(self): diff --git a/mlpf/pyg_pipeline.py b/mlpf/pyg_pipeline.py index e27cc035e..0e3a38a99 100644 --- a/mlpf/pyg_pipeline.py +++ b/mlpf/pyg_pipeline.py @@ -7,36 +7,42 @@ import argparse import logging import os -from pathlib import Path +import os.path as osp import pickle as pkl +from pathlib import Path import yaml os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" +import fastjet import torch import torch.distributed as dist import torch.multiprocessing as mp from pyg.inference import make_plots, run_predictions -from pyg.logger import _logger +from pyg.logger import _configLogger, _logger from pyg.mlpf import MLPF from pyg.training import train_mlpf -from pyg.utils import CLASS_LABELS, X_FEATURES, PFDataset, InterleavedIterator, save_mlpf +from pyg.utils import CLASS_LABELS, X_FEATURES, InterleavedIterator, PFDataset, save_HPs from utils import create_experiment_dir logging.basicConfig(level=logging.INFO) parser = argparse.ArgumentParser() -parser.add_argument("--config", type=str, default="parameters/pyg-config.yaml", help="yaml config") +parser.add_argument("--config", type=str, default="parameters/pyg-cms.yaml", help="yaml config") parser.add_argument("--prefix", type=str, default="test_", help="prefix appended to result dir name") -parser.add_argument("--overwrite", dest="overwrite", action="store_true", help="overwrites the model if True") parser.add_argument("--data_dir", type=str, default="/pfvol/tensorflow_datasets/", help="path to `tensorflow_datasets/`") 
parser.add_argument("--gpus", type=str, default="0", help="to use CPU set to empty string; else e.g., `0,1`") parser.add_argument( - "--gpu-batch-multiplier", type=int, default=1, help="Increase batch size per GPU by this constant factor" + "--gpu-batch-multiplier", type=int, default=1, help="increase batch size per GPU by this constant factor" ) parser.add_argument("--dataset", type=str, choices=["clic", "cms", "delphes"], required=True, help="which dataset?") +parser.add_argument("--ntrain", type=int, default=None, help="training samples to use, if None use entire dataset") +parser.add_argument("--ntest", type=int, default=None, help="testing samples to use, if None use entire dataset") +parser.add_argument("--nvalid", type=int, default=500, help="validation samples to use, default will use 500 events") +parser.add_argument("--num-workers", type=int, default=None, help="number of processes to load the data") +parser.add_argument("--prefetch-factor", type=int, default=2, help="number of samples to fetch & prefetch at every call") parser.add_argument("--load", type=str, default=None, help="dir from which to load a saved model") parser.add_argument("--train", action="store_true", help="initiates a training") parser.add_argument("--test", action="store_true", help="tests the model") @@ -46,11 +52,9 @@ parser.add_argument("--conv-type", type=str, default="gravnet", help="choices are ['gnn-lsh', 'gravnet', 'attention']") parser.add_argument("--make-plots", action="store_true", help="make plots of the test predictions") parser.add_argument("--export-onnx", action="store_true", help="exports the model to onnx") -parser.add_argument("--ntrain", type=int, default=None, help="training samples to use, if None use entire dataset") -parser.add_argument("--ntest", type=int, default=None, help="training samples to use, if None use entire dataset") -def run(rank, world_size, args): +def run(rank, world_size, args, outdir, logfile): """Demo function that will be passed to each gpu if (world_size > 1) else will run normally on the given device.""" if world_size > 1: @@ -58,32 +62,39 @@ def run(rank, world_size, args): os.environ["MASTER_PORT"] = "12355" dist.init_process_group("nccl", rank=rank, world_size=world_size) # (nccl should be faster than gloo) + if (rank == 0) or (rank == "cpu"): # keep writing the logs + _configLogger("mlpf", filename=logfile) + with open(args.config, "r") as stream: # load config (includes: which physics samples, model params) config = yaml.safe_load(stream) if args.load: # load a pre-trained model - outdir = args.load - + outdir = args.load # in case both --load and --train are provided with open(f"{outdir}/model_kwargs.pkl", "rb") as f: model_kwargs = pkl.load(f) model = MLPF(**model_kwargs) + optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr) + + checkpoint = torch.load(f"{outdir}/best_weights.pth", map_location=torch.device(rank)) - model_state = torch.load(f"{outdir}/best_epoch_weights.pth", map_location=torch.device(rank)) if isinstance(model, torch.nn.parallel.DistributedDataParallel): - model.module.load_state_dict(model_state) + model.module.load_state_dict(checkpoint["model_state_dict"]) else: - model.load_state_dict(model_state) + model.load_state_dict(checkpoint["model_state_dict"]) + optimizer.load_state_dict(checkpoint["optimizer_state_dict"]) + if (rank == 0) or (rank == "cpu"): - _logger.info(f"Loaded model weights from {outdir}/best_epoch_weight.pth") + _logger.info(f"Loaded model weights from {outdir}/best_weights.pth") - else: # instantiate a new 
model + else: # instantiate a new model in the outdir created model_kwargs = { "input_dim": len(X_FEATURES[args.dataset]), "num_classes": len(CLASS_LABELS[args.dataset]), **config["model"][args.conv_type], } model = MLPF(**model_kwargs) + optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr) model.to(rank) @@ -95,15 +106,12 @@ def run(rank, world_size, args): _logger.info(model) if args.train: - # always create a new outdir when training a model to never overwrite - # loaded weights from previous trainings if (rank == 0) or (rank == "cpu"): - outdir = create_experiment_dir(prefix=args.prefix + Path(args.config).stem + "_") - save_mlpf(args, model, model_kwargs, outdir) # save model_kwargs and hyperparameters - _logger.info("Creating experiment dir {}".format(outdir)) + save_HPs(args, model, model_kwargs, outdir) # save model_kwargs and hyperparameters + _logger.info(f"Creating experiment dir {outdir}") _logger.info(f"Model directory {outdir}", color="bold") - train_loaders, valid_loaders = [], [] + train_loaders = [] for sample in config["train_dataset"][args.dataset]: version = config["train_dataset"][args.dataset][sample]["version"] batch_size = config["train_dataset"][args.dataset][sample]["batch_size"] * args.gpu_batch_multiplier @@ -111,36 +119,38 @@ def run(rank, world_size, args): ds = PFDataset(args.data_dir, f"{sample}:{version}", "train", ["X", "ygen"], num_samples=args.ntrain) _logger.info(f"train_dataset: {ds}, {len(ds)}", color="blue") - train_loaders.append(ds.get_loader(batch_size=batch_size, world_size=world_size)) + train_loaders.append(ds.get_loader(batch_size, world_size, args.num_workers, args.prefetch_factor)) - if (rank == 0) or (rank == "cpu"): # validation only on a single machine - version = config["train_dataset"][args.dataset][sample]["version"] - batch_size = config["train_dataset"][args.dataset][sample]["batch_size"] * args.gpu_batch_multiplier + train_loader = InterleavedIterator(train_loaders) + + if (rank == 0) or (rank == "cpu"): # quick validation only on a single machine + valid_loaders = [] + for sample in config["valid_dataset"][args.dataset]: + version = config["valid_dataset"][args.dataset][sample]["version"] + batch_size = config["valid_dataset"][args.dataset][sample]["batch_size"] * args.gpu_batch_multiplier - ds = PFDataset(args.data_dir, f"{sample}:{version}", "test", ["X", "ygen", "ycand"], num_samples=args.ntest) + ds = PFDataset(args.data_dir, f"{sample}:{version}", "test", ["X", "ygen", "ycand"], num_samples=args.nvalid) _logger.info(f"valid_dataset: {ds}, {len(ds)}", color="blue") - valid_loaders.append(ds.get_loader(batch_size=batch_size, world_size=1)) + valid_loaders.append(ds.get_loader(batch_size, 1, args.num_workers, args.prefetch_factor)) - train_loader = InterleavedIterator(train_loaders) - valid_loader = None - if (rank == 0) or (rank == "cpu"): # validation only on a single machine valid_loader = InterleavedIterator(valid_loaders) + else: + valid_loader = None train_mlpf( rank, world_size, model, + optimizer, train_loader, valid_loader, args.num_epochs, args.patience, - args.lr, outdir, ) if args.test: - if args.load is None: # if we don't load, we must have a newly trained model assert args.train, "Please train a model before testing, or load a model with --load" @@ -153,21 +163,35 @@ def run(rank, world_size, args): version = config["test_dataset"][args.dataset][sample]["version"] batch_size = config["test_dataset"][args.dataset][sample]["batch_size"] * args.gpu_batch_multiplier - ds = PFDataset(args.data_dir, 
f"{sample}:{version}", "test", ["X", "ygen", "ycand"], num_samples=args.ntest) + ds = PFDataset(args.data_dir, f"{sample}:{version}", "test", ["X", "ygen", "ycand"], args.ntest) _logger.info(f"test_dataset: {ds}, {len(ds)}", color="blue") - test_loaders[sample] = InterleavedIterator([ds.get_loader(batch_size=batch_size, world_size=world_size)]) + test_loaders[sample] = InterleavedIterator( + [ds.get_loader(batch_size, world_size, args.num_workers, args.prefetch_factor)] + ) + + if not osp.isdir(f"{outdir}/preds/{sample}"): + if (rank == 0) or (rank == "cpu"): + os.system(f"mkdir -p {outdir}/preds/{sample}") + + checkpoint = torch.load(f"{outdir}/best_weights.pth", map_location=torch.device(rank)) - model_state = torch.load(f"{outdir}/best_epoch_weights.pth", map_location=torch.device(rank)) if isinstance(model, torch.nn.parallel.DistributedDataParallel): - model.module.load_state_dict(model_state) + model.module.load_state_dict(checkpoint["model_state_dict"]) else: - model.load_state_dict(model_state) + model.load_state_dict(checkpoint["model_state_dict"]) + optimizer.load_state_dict(checkpoint["optimizer_state_dict"]) for sample in test_loaders: _logger.info(f"Running predictions on {sample}") torch.cuda.empty_cache() - run_predictions(rank, model, test_loaders[sample], sample, outdir) + + if args.dataset == "clic": + jetdef = fastjet.JetDefinition(fastjet.ee_genkt_algorithm, 0.7, -1.0) + else: + jetdef = fastjet.JetDefinition(fastjet.antikt_algorithm, 0.4) + + run_predictions(rank, model, test_loaders[sample], sample, outdir, jetdef, jet_ptcut=5.0, jet_match_dr=0.1) if (rank == 0) or (rank == "cpu"): # make plots and export to onnx only on a single machine if args.make_plots: @@ -206,6 +230,19 @@ def main(): args = parser.parse_args() world_size = len(args.gpus.split(",")) # will be 1 for both cpu ("") and single-gpu ("0") + if args.train: # create a new outdir when training a model to never overwrite + outdir = create_experiment_dir(prefix=args.prefix + Path(args.config).stem + "_") + logfile = f"{outdir}/train.log" + _configLogger("mlpf", filename=logfile) + + os.system(f"cp {args.config} {outdir}/train-config.yaml") + else: + outdir = args.load + logfile = f"{outdir}/test.log" + _configLogger("mlpf", filename=logfile) + + os.system(f"cp {args.config} {outdir}/test-config.yaml") + if args.gpus: assert ( world_size <= torch.cuda.device_count() @@ -219,19 +256,19 @@ def main(): mp.spawn( run, - args=(world_size, args), + args=(world_size, args, outdir, logfile), nprocs=world_size, join=True, ) elif world_size == 1: rank = 0 _logger.info(f"Will use single-gpu: {torch.cuda.get_device_name(rank)}", color="purple") - run(rank, world_size, args) + run(rank, world_size, args, outdir, logfile) else: rank = "cpu" _logger.info("Will use cpu", color="purple") - run(rank, world_size, args) + run(rank, world_size, args, outdir, logfile) if __name__ == "__main__": diff --git a/mlpf/ssl_evaluate.py b/mlpf/ssl_evaluate.py deleted file mode 100644 index 73866d5dd..000000000 --- a/mlpf/ssl_evaluate.py +++ /dev/null @@ -1,102 +0,0 @@ -import datetime -import os.path as osp -import pickle as pkl -import platform - -import torch -from pyg.mlpf import MLPF -from pyg.ssl.args import parse_args -from pyg.ssl.evaluate import evaluate -from pyg.ssl.utils import data_split, load_VICReg -from pyg.ssl.VICReg import DECODER, ENCODER, VICReg - -if __name__ == "__main__": - import sys - - sys.path.append("") - - # define the global base device - if torch.cuda.device_count(): - device = torch.device("cuda:0") - 
print(f"Will use {torch.cuda.get_device_name(device)}") - else: - device = "cpu" - print("Will use cpu") - - args = parse_args() - - # load the clic dataset - _, _, _, _, data_test_qcd, data_test_ttbar = data_split(args.data_path, args.data_split_mode) - - # setup the directory path to hold all models and plots - if args.prefix_VICReg is None: - args.prefix_VICReg = "pyg_" + datetime.datetime.now().strftime("%Y%m%d_%H%M%S_%f") + "." + platform.node() - outpath = osp.join(args.outpath, args.prefix_VICReg) - - # load a pre-trained VICReg model - vicreg_state_dict, encoder_model_kwargs, decoder_model_kwargs = load_VICReg(device, outpath) - - vicreg_encoder = ENCODER(**encoder_model_kwargs) - vicreg_decoder = DECODER(**decoder_model_kwargs) - - vicreg = VICReg(vicreg_encoder, vicreg_decoder) - - try: - vicreg.load_state_dict(vicreg_state_dict) - except RuntimeError: - # if the mlpf model was saved using torch.nn.DataParallel() - from collections import OrderedDict - - new_state_dict = OrderedDict() - for k, v in vicreg_state_dict.items(): - name = k[7:] # remove module. - new_state_dict[name] = v - vicreg_state_dict = new_state_dict - vicreg.load_state_dict(vicreg_state_dict) - - vicreg.to(device) - - # load a pre-trained MLPF model - if args.ssl: - outpath_ssl = osp.join(f"{outpath}/MLPF/", f"{args.prefix}_ssl") - - print("Loading a previously trained ssl model..") - mlpf_ssl_state_dict = torch.load(f"{outpath_ssl}/best_epoch_weights.pth", map_location=device) - - with open(f"{outpath_ssl}/model_kwargs.pkl", "rb") as f: - mlpf_model_kwargs = pkl.load(f) - - mlpf_ssl = MLPF(**mlpf_model_kwargs).to(device) - mlpf_ssl.load_state_dict(mlpf_ssl_state_dict) - - ret_ssl = evaluate( - device, - vicreg_encoder, - mlpf_ssl, - args.bs, - "ssl", - outpath_ssl, - {"QCD": data_test_qcd, "TTBar": data_test_ttbar}, - ) - - if args.native: - outpath_native = osp.join(f"{outpath}/MLPF/", f"{args.prefix}_native") - - print("Loading a previously trained ssl model..") - mlpf_native_state_dict = torch.load(f"{outpath_native}/best_epoch_weights.pth", map_location=device) - - with open(f"{outpath_native}/model_kwargs.pkl", "rb") as f: - mlpf_model_kwargs = pkl.load(f) - - mlpf_native = MLPF(**mlpf_model_kwargs).to(device) - mlpf_native.load_state_dict(mlpf_native_state_dict) - - ret_native = evaluate( - device, - vicreg_encoder, - mlpf_native, - args.bs, - "native", - outpath_native, - {"QCD": data_test_qcd, "TTBar": data_test_ttbar}, - ) diff --git a/mlpf/ssl_pipeline.py b/mlpf/ssl_pipeline.py deleted file mode 100644 index 2a6427d0d..000000000 --- a/mlpf/ssl_pipeline.py +++ /dev/null @@ -1,213 +0,0 @@ -import datetime -import os.path as osp -import platform - -import matplotlib -import mplhep -import numpy as np -import torch -import torch_geometric -from pyg.mlpf import MLPF -from pyg.ssl.args import parse_args -from pyg.ssl.training_VICReg import training_loop_VICReg -from pyg.ssl.utils import CLUSTERS_X, TRACKS_X, data_split, load_VICReg, save_VICReg -from pyg.ssl.VICReg import DECODER, ENCODER, VICReg -from pyg.training import training_loop -from pyg.utils import save_mlpf - -matplotlib.use("Agg") -mplhep.style.use(mplhep.styles.CMS) - -""" -Developing a PyTorch Geometric semi-supervised (VICReg-based https://arxiv.org/abs/2105.04906) pipeline -for particleflow reconstruction on CLIC datasets. - -Authors: Farouk Mokhtar, Joosep Pata. 
-""" - - -# Ignore divide by 0 errors -np.seterr(divide="ignore", invalid="ignore") - -# define the global base device(s) -if torch.cuda.device_count(): - device = torch.device("cuda:0") - print(f"Will use {torch.cuda.get_device_name(device)}") -else: - device = "cpu" - print("Will use cpu") -multi_gpu = torch.cuda.device_count() > 1 - -if __name__ == "__main__": - - args = parse_args() - - world_size = torch.cuda.device_count() - - # our data size varies from batch to batch, because each set of N_batch events has a different number of particles - torch.backends.cudnn.benchmark = False - # torch.autograd.set_detect_anomaly(True) - - # load the clic dataset - data_VICReg_train, data_VICReg_valid, data_mlpf_train, data_mlpf_valid, data_test_qcd, data_test_ttbar = data_split( - args.data_path, args.data_split_mode - ) - - # setup the directory path to hold all models and plots - if args.prefix_VICReg is None: - args.prefix_VICReg = "pyg_" + datetime.datetime.now().strftime("%Y%m%d_%H%M%S_%f") + "." + platform.node() - outpath = osp.join(args.outpath, args.prefix_VICReg) - - # load a pre-trained VICReg model - if args.load_VICReg: - vicreg_state_dict, encoder_model_kwargs, decoder_model_kwargs = load_VICReg(device, outpath) - - vicreg_encoder = ENCODER(**encoder_model_kwargs) - vicreg_decoder = DECODER(**decoder_model_kwargs) - - vicreg = VICReg(vicreg_encoder, vicreg_decoder) - - try: - vicreg.load_state_dict(vicreg_state_dict) - except RuntimeError: - # if the mlpf model was saved using torch.nn.DataParallel() - from collections import OrderedDict - - new_state_dict = OrderedDict() - for k, v in vicreg_state_dict.items(): - name = k[7:] # remove module. - new_state_dict[name] = v - vicreg_state_dict = new_state_dict - vicreg.load_state_dict(vicreg_state_dict) - - vicreg.load_state_dict(vicreg_state_dict) - vicreg.to(device) - - else: - encoder_model_kwargs = { - "embedding_dim": args.embedding_dim_VICReg, - "width": args.width_encoder, - "num_convs": args.num_convs_VICReg, - "space_dim": args.space_dim, - "propagate_dim": args.propagate_dim, - "k": args.nearest, - } - - decoder_model_kwargs = { - "input_dim": args.embedding_dim_VICReg, - "output_dim": args.expand_dim, - "width": args.width_decoder, - } - - vicreg_encoder = ENCODER(**encoder_model_kwargs) - vicreg_decoder = DECODER(**decoder_model_kwargs) - vicreg = VICReg(vicreg_encoder, vicreg_decoder) - vicreg.to(device) - - # save model_kwargs and hyperparameters - save_VICReg(args, outpath, vicreg_encoder, encoder_model_kwargs, vicreg_decoder, decoder_model_kwargs) - - if multi_gpu: - vicreg = torch_geometric.nn.DataParallel(vicreg) - train_loader = torch_geometric.loader.DataListLoader(data_VICReg_train, args.bs_VICReg) - valid_loader = torch_geometric.loader.DataListLoader(data_VICReg_valid, args.bs_VICReg) - else: - train_loader = torch_geometric.loader.DataLoader(data_VICReg_train, args.bs_VICReg) - valid_loader = torch_geometric.loader.DataLoader(data_VICReg_valid, args.bs_VICReg) - - optimizer = torch.optim.SGD(vicreg.parameters(), lr=args.lr) - - print(vicreg) - print(f"VICReg model name: {args.prefix_VICReg}") - print(f"Training VICReg over {args.n_epochs_VICReg} epochs") - training_loop_VICReg( - multi_gpu, - device, - vicreg, - {"train": train_loader, "valid": valid_loader}, - args.n_epochs_VICReg, - args.patience, - optimizer, - {"lmbd": args.lmbd, "mu": args.mu, "nu": args.nu}, - outpath, - ) - - if args.train_mlpf: - print("------> Progressing to MLPF trainings...") - print(f"Will use {len(data_mlpf_train)} events for 
train") - print(f"Will use {len(data_mlpf_valid)} events for valid") - - train_loader = [torch_geometric.loader.DataLoader(data_mlpf_train, args.bs)] - valid_loader = [torch_geometric.loader.DataLoader(data_mlpf_valid, args.bs)] - - input_ = max(CLUSTERS_X, TRACKS_X) + 1 # max cz we pad when we concatenate them & +1 cz there's the `type` feature - - if args.ssl: - - mlpf_model_kwargs = { - "input_dim": input_, - "embedding_dim": args.embedding_dim, - "width": args.width, - "num_convs": args.num_convs, - "k": args.nearest, - "dropout": args.dropout, - "ssl": True, - "VICReg_embedding_dim": args.embedding_dim_VICReg, - } - - mlpf_ssl = MLPF(**mlpf_model_kwargs).to(device) - print(mlpf_ssl) - print(f"MLPF model name: {args.prefix}_ssl") - print(f"Will use VICReg model {args.prefix_VICReg}") - - # make mlpf specific directory - outpath_ssl = osp.join(f"{outpath}/MLPF/", f"{args.prefix}_ssl") - save_mlpf(args, outpath_ssl, mlpf_ssl, mlpf_model_kwargs, mode="ssl") - - print(f"- Training ssl based MLPF over {args.n_epochs} epochs") - - training_loop( - device, - mlpf_ssl, - train_loader, - valid_loader, - args.bs, - args.n_epochs, - args.patience, - args.lr, - outpath_ssl, - vicreg_encoder, - ) - - if args.native: - - mlpf_model_kwargs = { - "input_dim": input_, - "embedding_dim": args.embedding_dim, - "width": args.width, - "num_convs": args.num_convs, - "k": args.nearest, - "dropout": args.dropout, - } - - mlpf_native = MLPF(**mlpf_model_kwargs).to(device) - print(mlpf_native) - print(f"MLPF model name: {args.prefix}_native") - - # make mlpf specific directory - outpath_native = osp.join(f"{outpath}/MLPF/", f"{args.prefix}_native") - save_mlpf(args, outpath_native, mlpf_native, mlpf_model_kwargs, mode="native") - - print(f"- Training native MLPF over {args.n_epochs} epochs") - - training_loop( - device, - mlpf_native, - train_loader, - valid_loader, - args.bs, - args.n_epochs, - args.patience, - args.lr, - outpath_native, - ) diff --git a/mlpf/utils.py b/mlpf/utils.py index 9ff96be60..624906a4d 100644 --- a/mlpf/utils.py +++ b/mlpf/utils.py @@ -1,6 +1,6 @@ -from pathlib import Path import datetime import platform +from pathlib import Path def create_experiment_dir(prefix=None, suffix=None): @@ -13,4 +13,5 @@ def create_experiment_dir(prefix=None, suffix=None): train_dir = train_dir.with_name(train_dir.name + "." 
+ platform.node()) train_dir.mkdir(parents=True) + return str(train_dir) diff --git a/parameters/pyg-clic.yaml b/parameters/pyg-clic.yaml new file mode 100644 index 000000000..642a911fa --- /dev/null +++ b/parameters/pyg-clic.yaml @@ -0,0 +1,68 @@ +backend: pytorch + +model: + gnn-lsh: + conv_type: gnn-lsh + embedding_dim: 512 + width: 512 + num_convs: 3 + dropout: 0.0 + + gravnet: + conv_type: gravnet + embedding_dim: 512 + width: 512 + num_convs: 3 + k: 16 + propagate_dimensions: 22 + space_dimensions: 4 + dropout: 0.0 + + attention: + conv_type: attention + embedding_dim: 128 + width: 128 + num_convs: 2 + dropout: 0.0 + +train_dataset: + clic: + clic_edm_qq_pf: + version: 1.5.0 + batch_size: 100 + clic_edm_ttbar_pf: + version: 1.5.0 + batch_size: 100 + clic_edm_ttbar_pu10_pf: + version: 1.5.0 + batch_size: 100 + clic_edm_ww_fullhad_pf: + version: 1.5.0 + batch_size: 100 + clic_edm_zh_tautau_pf: + version: 1.5.0 + batch_size: 100 + +valid_dataset: + clic: + clic_edm_qq_pf: + version: 1.5.0 + batch_size: 100 + +test_dataset: + clic: + clic_edm_qq_pf: + version: 1.5.0 + batch_size: 100 + clic_edm_ttbar_pf: + version: 1.5.0 + batch_size: 100 + clic_edm_ttbar_pu10_pf: + version: 1.5.0 + batch_size: 100 + clic_edm_ww_fullhad_pf: + version: 1.5.0 + batch_size: 100 + clic_edm_zh_tautau_pf: + version: 1.5.0 + batch_size: 100 diff --git a/parameters/pyg-cms-small.yaml b/parameters/pyg-cms-small.yaml new file mode 100644 index 000000000..56e8773af --- /dev/null +++ b/parameters/pyg-cms-small.yaml @@ -0,0 +1,86 @@ +backend: pytorch + +model: + gnn-lsh: + conv_type: gnn-lsh + embedding_dim: 512 + width: 512 + num_convs: 3 + dropout: 0.0 + + gravnet: + conv_type: gravnet + embedding_dim: 512 + width: 512 + num_convs: 3 + k: 16 + propagate_dimensions: 22 + space_dimensions: 4 + dropout: 0.0 + + attention: + conv_type: attention + embedding_dim: 128 + width: 128 + num_convs: 2 + dropout: 0.0 + +train_dataset: + cms: + cms_pf_ttbar: + version: 1.6.0 + batch_size: 1 + cms_pf_qcd: + version: 1.6.0 + batch_size: 1 + +valid_dataset: + cms: + cms_pf_qcd_high_pt: + version: 1.6.0 + batch_size: 1 + +test_dataset: + cms: + cms_pf_ttbar: + version: 1.6.0 + batch_size: 1 + cms_pf_qcd: + version: 1.6.0 + batch_size: 1 + cms_pf_ztt: + version: 1.6.0 + batch_size: 1 + cms_pf_qcd_high_pt: + version: 1.6.0 + batch_size: 1 + cms_pf_sms_t1tttt: + version: 1.6.0 + batch_size: 1 + cms_pf_single_electron: + version: 1.6.0 + batch_size: 20 + cms_pf_single_gamma: + version: 1.6.0 + batch_size: 20 + cms_pf_single_pi0: + version: 1.6.0 + batch_size: 20 + cms_pf_single_neutron: + version: 1.6.0 + batch_size: 20 + cms_pf_single_pi: + version: 1.6.0 + batch_size: 20 + cms_pf_single_tau: + version: 1.6.0 + batch_size: 20 + cms_pf_single_mu: + version: 1.6.0 + batch_size: 20 + cms_pf_single_proton: + version: 1.6.0 + batch_size: 20 + cms_pf_multi_particle_gun: + version: 1.6.0 + batch_size: 5 diff --git a/parameters/pyg-cms-test-qcdhighpt.yaml b/parameters/pyg-cms-test-qcdhighpt.yaml new file mode 100644 index 000000000..dad102a0b --- /dev/null +++ b/parameters/pyg-cms-test-qcdhighpt.yaml @@ -0,0 +1,32 @@ +backend: pytorch + +model: + gnn-lsh: + conv_type: gnn-lsh + embedding_dim: 512 + width: 512 + num_convs: 3 + dropout: 0.0 + + gravnet: + conv_type: gravnet + embedding_dim: 512 + width: 512 + num_convs: 3 + k: 16 + propagate_dimensions: 22 + space_dimensions: 4 + dropout: 0.0 + + attention: + conv_type: attention + embedding_dim: 128 + width: 128 + num_convs: 2 + dropout: 0.0 + +test_dataset: + cms: + cms_pf_qcd_high_pt: + 
version: 1.6.0 + batch_size: 1 diff --git a/parameters/pyg-config.yaml b/parameters/pyg-cms.yaml similarity index 67% rename from parameters/pyg-config.yaml rename to parameters/pyg-cms.yaml index affe7134e..4cc2ff30e 100644 --- a/parameters/pyg-config.yaml +++ b/parameters/pyg-cms.yaml @@ -10,9 +10,9 @@ model: gravnet: conv_type: gravnet - embedding_dim: 36 - width: 128 - num_convs: 2 + embedding_dim: 512 + width: 512 + num_convs: 3 k: 16 propagate_dimensions: 22 space_dimensions: 4 @@ -69,29 +69,12 @@ train_dataset: cms_pf_multi_particle_gun: version: 1.6.0 batch_size: 5 - clic: - clic_edm_qq_pf: - version: 1.5.0 - batch_size: 100 - clic_edm_ttbar_pf: - version: 1.5.0 - batch_size: 100 - # clic_edm_ttbar_pu10_pf: - # version: 1.5.0 - # batch_size: 100 - # clic_edm_ww_fullhad_pf: - # version: 1.5.0 - # batch_size: 100 - # clic_edm_zh_tautau_pf: - # version: 1.5.0 - # batch_size: 100 - delphes: - delphes_ttbar_pf: - version: 1.2.0 - batch_size: 10 - delphes_qcd_pf: - version: 1.2.0 - batch_size: 10 + +valid_dataset: + cms: + cms_pf_qcd_high_pt: + version: 1.6.0 + batch_size: 1 test_dataset: cms: @@ -137,26 +120,3 @@ test_dataset: cms_pf_multi_particle_gun: version: 1.6.0 batch_size: 5 - clic: - clic_edm_qq_pf: - version: 1.5.0 - batch_size: 100 - clic_edm_ttbar_pf: - version: 1.5.0 - batch_size: 100 - # clic_edm_ttbar_pu10_pf: - # version: 1.5.0 - # batch_size: 100 - # clic_edm_ww_fullhad_pf: - # version: 1.5.0 - # batch_size: 100 - # clic_edm_zh_tautau_pf: - # version: 1.5.0 - # batch_size: 100 - delphes: - delphes_ttbar_pf: - version: 1.2.0 - batch_size: 10 - delphes_qcd_pf: - version: 1.2.0 - batch_size: 10 diff --git a/parameters/pyg-delphes.yaml b/parameters/pyg-delphes.yaml new file mode 100644 index 000000000..42b09495b --- /dev/null +++ b/parameters/pyg-delphes.yaml @@ -0,0 +1,50 @@ +backend: pytorch + +model: + gnn-lsh: + conv_type: gnn-lsh + embedding_dim: 512 + width: 512 + num_convs: 3 + dropout: 0.0 + + gravnet: + conv_type: gravnet + embedding_dim: 512 + width: 512 + num_convs: 3 + k: 16 + propagate_dimensions: 22 + space_dimensions: 4 + dropout: 0.0 + + attention: + conv_type: attention + embedding_dim: 128 + width: 128 + num_convs: 2 + dropout: 0.0 + +train_dataset: + delphes: + delphes_ttbar_pf: + version: 1.2.0 + batch_size: 10 + delphes_qcd_pf: + version: 1.2.0 + batch_size: 10 + +valid_dataset: + delphes: + delphes_qcd_pf: + version: 1.2.0 + batch_size: 10 + +test_dataset: + delphes: + delphes_ttbar_pf: + version: 1.2.0 + batch_size: 10 + delphes_qcd_pf: + version: 1.2.0 + batch_size: 10 diff --git a/parameters/pyg-config-test.yaml b/parameters/pyg-workflow-test.yaml similarity index 58% rename from parameters/pyg-config-test.yaml rename to parameters/pyg-workflow-test.yaml index f3d75defa..eafc3cb3d 100644 --- a/parameters/pyg-config-test.yaml +++ b/parameters/pyg-workflow-test.yaml @@ -3,39 +3,38 @@ backend: pytorch model: gnn-lsh: conv_type: gnn-lsh - embedding_dim: 36 - width: 128 - num_convs: 2 - k: 16 - propagate_dimensions: 22 - space_dimensions: 4 - dropout: True + embedding_dim: 512 + width: 512 + num_convs: 3 + dropout: 0.0 gravnet: conv_type: gravnet - embedding_dim: 36 - width: 128 - num_convs: 2 + embedding_dim: 512 + width: 512 + num_convs: 3 k: 16 propagate_dimensions: 22 space_dimensions: 4 - dropout: True + dropout: 0.0 attention: conv_type: attention embedding_dim: 128 width: 128 num_convs: 2 - k: 16 - propagate_dimensions: 22 - space_dimensions: 4 - dropout: True + dropout: 0.0 train_dataset: cms: cms_pf_ttbar: version: 1.6.0 batch_size: 2 
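+# valid_dataset below feeds the quick validation runs performed at fixed step intervals during training (see mlpf/pyg/training.py)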
+valid_dataset: + cms: + cms_pf_ttbar: + version: 1.6.0 + batch_size: 2 test_dataset: cms: cms_pf_ttbar: diff --git a/scripts/local_test_pyg.sh b/scripts/local_test_pyg.sh index a9f157906..7ba0e80ab 100755 --- a/scripts/local_test_pyg.sh +++ b/scripts/local_test_pyg.sh @@ -27,4 +27,4 @@ mkdir -p experiments tfds build mlpf/heptfds/cms_pf/ttbar --manual_dir ./local_test_data -python mlpf/pyg_pipeline.py --config parameters/pyg-config-test.yaml --dataset cms --data_dir ./tensorflow_datasets/ --prefix MLPF_test --gpus "" --train --test --make-plots +python mlpf/pyg_pipeline.py --config parameters/pyg-workflow-test.yaml --dataset cms --data_dir ./tensorflow_datasets/ --prefix MLPF_test_ --nvalid 1 --gpus "" --train --test --make-plots diff --git a/tests/test_torch_and_tf.py b/tests/test_torch_and_tf.py index 56bfc96fc..cc174587f 100644 --- a/tests/test_torch_and_tf.py +++ b/tests/test_torch_and_tf.py @@ -1,4 +1,5 @@ import unittest + import numpy as np import torch import tensorflow as tf @@ -9,9 +10,11 @@ class TestGNNTorchAndTensorflow(unittest.TestCase): def test_GHConvDense(self): from mlpf.tfmodel.model import GHConvDense - from mlpf.pyg.model import GHConvDense as GHConvDenseTorch nn1 = GHConvDense(output_dim=128, activation="selu") + + from mlpf.pyg.gnn_lsh import GHConvDense as GHConvDenseTorch + nn2 = GHConvDenseTorch(output_dim=128, activation="selu", hidden_dim=64) x = np.random.normal(size=(2, 4, 64, 64)).astype(np.float32) @@ -39,9 +42,11 @@ def test_GHConvDense(self): def test_MessageBuildingLayerLSH(self): from mlpf.tfmodel.model import MessageBuildingLayerLSH - from mlpf.pyg.model import MessageBuildingLayerLSH as MessageBuildingLayerLSHTorch nn1 = MessageBuildingLayerLSH(distance_dim=128, bin_size=64) + + from mlpf.pyg.gnn_lsh import MessageBuildingLayerLSH as MessageBuildingLayerLSHTorch + nn2 = MessageBuildingLayerLSHTorch(distance_dim=128, bin_size=64) x_dist = np.random.normal(size=(2, 256, 128)).astype(np.float32) @@ -71,7 +76,7 @@ def test_MessageBuildingLayerLSH(self): ret = reverse_lsh(bins_split, x, False) self.assertTrue(np.all(x_node == ret.numpy())) - from mlpf.pyg.model import reverse_lsh as reverse_lsh_torch + from mlpf.pyg.gnn_lsh import reverse_lsh as reverse_lsh_torch bins_split, x, dm, msk_f = out2 ret = reverse_lsh_torch(bins_split, x)