diff --git a/mlpf/model/PFDataset.py b/mlpf/model/PFDataset.py index 57802023f..10e00af19 100644 --- a/mlpf/model/PFDataset.py +++ b/mlpf/model/PFDataset.py @@ -1,14 +1,13 @@ +import sys from types import SimpleNamespace +import numpy as np import tensorflow_datasets as tfds import torch import torch.utils.data from mlpf.model.logger import _logger -import numpy as np -import sys - class TFDSDataSource: def __init__(self, ds, sort): @@ -117,7 +116,9 @@ def __init__(self, data_dir, name, split, num_samples=None, sort=False): builder = tfds.builder(name, data_dir=data_dir) except Exception: _logger.error( - "Could not find dataset {} in {}, please check that you have downloaded the correct version of the dataset".format(name, data_dir) + "Could not find dataset {} in {}, please check that you have downloaded the correct version of the dataset".format( + name, data_dir + ) ) sys.exit(1) self.ds = TFDSDataSource(builder.as_data_source(split=split), sort=sort) @@ -156,7 +157,9 @@ def to(self, device, **kwargs): class Collater: def __init__(self, per_particle_keys_to_get, per_event_keys_to_get, **kwargs): super(Collater, self).__init__(**kwargs) - self.per_particle_keys_to_get = per_particle_keys_to_get # these quantities are a variable-length tensor per each event + self.per_particle_keys_to_get = ( + per_particle_keys_to_get # these quantities are a variable-length tensor per each event + ) self.per_event_keys_to_get = per_event_keys_to_get # these quantities are one value (scalar) per event def __call__(self, inputs): @@ -164,7 +167,9 @@ def __call__(self, inputs): # per-particle quantities need to be padded across events of different size for key_to_get in self.per_particle_keys_to_get: - ret[key_to_get] = torch.nn.utils.rnn.pad_sequence([torch.tensor(inp[key_to_get]).to(torch.float32) for inp in inputs], batch_first=True) + ret[key_to_get] = torch.nn.utils.rnn.pad_sequence( + [torch.tensor(inp[key_to_get]).to(torch.float32) for inp in inputs], batch_first=True + ) # per-event quantities can be stacked across events for key_to_get in self.per_event_keys_to_get: @@ -229,12 +234,16 @@ def get_interleaved_dataloaders(world_size, rank, config, use_cuda, use_ray): split_configs = config[f"{split}_dataset"][config["dataset"]][type_]["samples"][sample]["splits"] print("split_configs", split_configs) + nevents = None + if not (config[f"n{split}"] is None): + nevents = config[f"n{split}"] // len(split_configs) + for split_config in split_configs: ds = PFDataset( config["data_dir"], f"{sample}/{split_config}:{version}", split, - num_samples=config[f"n{split}"], + num_samples=nevents, sort=config["sort_data"], ).ds @@ -258,7 +267,9 @@ def get_interleaved_dataloaders(world_size, rank, config, use_cuda, use_ray): loader = torch.utils.data.DataLoader( dataset, batch_size=batch_size, - collate_fn=Collater(["X", "ytarget", "ytarget_pt_orig", "ytarget_e_orig", "genjets", "targetjets"], ["genmet"]), + collate_fn=Collater( + ["X", "ytarget", "ytarget_pt_orig", "ytarget_e_orig", "genjets", "targetjets"], ["genmet"] + ), sampler=sampler, num_workers=config["num_workers"], prefetch_factor=config["prefetch_factor"], diff --git a/mlpf/model/inference.py b/mlpf/model/inference.py index ab9908653..e4d2d6a4d 100644 --- a/mlpf/model/inference.py +++ b/mlpf/model/inference.py @@ -10,29 +10,30 @@ import tqdm import vector from jet_utils import match_two_jet_collections -from plotting.plot_utils import ( - get_class_names, +from plotting.plot_utils import ( # plot_elements, compute_met_and_ratio, + get_class_names, load_eval_data, - plot_jets, plot_jet_ratio, plot_jet_response_binned, - plot_jet_response_binned_vstarget, plot_jet_response_binned_eta, + plot_jet_response_binned_vstarget, + plot_jets, plot_met, plot_met_ratio, plot_met_response_binned, plot_num_elements, - plot_particles, plot_particle_ratio, - # plot_elements, + plot_particles, ) from .logger import _logger from .utils import unpack_predictions, unpack_target -def predict_one_batch(conv_type, model, i, batch, rank, jetdef, jet_ptcut, jet_match_dr, outpath, dir_name, sample): +def predict_one_batch( + conv_type, model, i, batch, rank, jetdef, jet_ptcut, jet_etacut, jet_match_dr, outpath, dir_name, sample +): # skip prediction if output exists outfile = f"{outpath}/preds{dir_name}/{sample}/pred_{rank}_{i}.parquet" @@ -62,7 +63,8 @@ def predict_one_batch(conv_type, model, i, batch, rank, jetdef, jet_ptcut, jet_m ycand = unpack_target(batch.ycand.to(torch.float32), model) ypred = unpack_predictions(ypred) - genjets_msk = batch.genjets[:, :, 0].cpu() > jet_ptcut + genjets_msk = (batch.genjets[:, :, 0].cpu() > jet_ptcut) & (abs(batch.genjets[:, :, 1]).cpu() < jet_etacut) + genjets = awkward.unflatten(batch.genjets.cpu().to(torch.float64)[genjets_msk], torch.sum(genjets_msk, axis=1)) genjets = vector.awk( awkward.zip( @@ -125,7 +127,15 @@ def predict_one_batch(conv_type, model, i, batch, rank, jetdef, jet_ptcut, jet_m ) awkward.to_parquet( - awkward.Array({"inputs": Xs, "particles": awkvals, "jets": jets_coll, "matched_jets": matched_jets, "genmet": batch.genmet.cpu()}), + awkward.Array( + { + "inputs": Xs, + "particles": awkvals, + "jets": jets_coll, + "matched_jets": matched_jets, + "genmet": batch.genmet.cpu(), + } + ), outfile, ) _logger.info(f"Saved predictions at {outfile}") @@ -136,7 +146,9 @@ def predict_one_batch_args(args): @torch.no_grad() -def run_predictions(world_size, rank, model, loader, sample, outpath, jetdef, jet_ptcut=15.0, jet_match_dr=0.1, dir_name=""): +def run_predictions( + world_size, rank, model, loader, sample, outpath, jetdef, jet_ptcut=15.0, jet_etacut=2.5, jet_match_dr=0.1, dir_name="" +): """Runs inference on the given sample and stores the output as .parquet files.""" if world_size > 1: conv_type = model.module.conv_type @@ -153,7 +165,9 @@ def run_predictions(world_size, rank, model, loader, sample, outpath, jetdef, je ti = time.time() for i, batch in iterator: - predict_one_batch(conv_type, model, i, batch, rank, jetdef, jet_ptcut, jet_match_dr, outpath, dir_name, sample) + predict_one_batch( + conv_type, model, i, batch, rank, jetdef, jet_ptcut, jet_etacut, jet_match_dr, outpath, dir_name, sample + ) _logger.info(f"Time taken to make predictions on device {rank} is: {((time.time() - ti) / 60):.2f} min") diff --git a/mlpf/model/training.py b/mlpf/model/training.py index 102116618..e7e231a3c 100644 --- a/mlpf/model/training.py +++ b/mlpf/model/training.py @@ -1,59 +1,55 @@ +import csv +import glob +import json +import logging import os import os.path as osp import pickle as pkl +import shutil import time +from datetime import datetime from pathlib import Path from tempfile import TemporaryDirectory from typing import Optional -import logging -import shutil -from datetime import datetime -import tqdm -import yaml -import csv -import json -import sklearn -import sklearn.metrics -import numpy as np -import pandas + import matplotlib import matplotlib.pyplot as plt -import glob - -# comet needs to be imported before torch -from comet_ml import OfflineExperiment, Experiment # noqa: F401, isort:skip - +import numpy as np +import pandas +import sklearn +import sklearn.metrics import torch import torch.distributed as dist import torch.multiprocessing as mp +import tqdm +import yaml from torch import Tensor, nn from torch.nn import functional as F from torch.profiler import ProfilerActivity, profile, record_function from torch.utils.tensorboard import SummaryWriter -from mlpf.model.logger import _logger, _configLogger +from mlpf.model.inference import make_plots, run_predictions +from mlpf.model.logger import _configLogger, _logger +from mlpf.model.mlpf import MLPF, set_save_attention +from mlpf.model.PFDataset import Collater, PFDataset, get_interleaved_dataloaders from mlpf.model.utils import ( - unpack_predictions, - unpack_target, + CLASS_LABELS, + ELEM_TYPES_NONZERO, + X_FEATURES, + count_parameters, + get_lr_schedule, get_model_state_dict, load_checkpoint, save_checkpoint, - CLASS_LABELS, - X_FEATURES, - ELEM_TYPES_NONZERO, save_HPs, - get_lr_schedule, - count_parameters, + unpack_predictions, + unpack_target, ) - - -from mlpf.model.inference import make_plots, run_predictions -from mlpf.model.mlpf import set_save_attention -from mlpf.model.mlpf import MLPF -from mlpf.model.PFDataset import Collater, PFDataset, get_interleaved_dataloaders - from mlpf.utils import create_comet_experiment +# comet needs to be imported before torch +from comet_ml import OfflineExperiment, Experiment # noqa: F401, isort:skip + def sliced_wasserstein_loss(y_pred, y_true, num_projections=200): # create normalized random basis vectors @@ -95,7 +91,9 @@ def mlpf_loss(y, ypred, batch): # binary loss for particle / no-particle classification # loss_binary_classification = loss_obj_id(ypred["cls_binary"], (y["cls_id"] != 0).long()).reshape(y["cls_id"].shape) - loss_binary_classification = 10 * torch.nn.functional.cross_entropy(ypred["cls_binary"], (y["cls_id"] != 0).long(), reduction="none") + loss_binary_classification = 10 * torch.nn.functional.cross_entropy( + ypred["cls_binary"], (y["cls_id"] != 0).long(), reduction="none" + ) # compare the particle type, only for cases where there was a true particle loss_pid_classification = loss_obj_id(ypred["cls_id_onehot"], y["cls_id"]).reshape(y["cls_id"].shape) @@ -145,12 +143,12 @@ def mlpf_loss(y, ypred, batch): pred_met = torch.sqrt(torch.sum(pred_px, axis=-2) ** 2 + torch.sum(pred_py, axis=-2) ** 2).detach() loss["MET"] = torch.nn.functional.huber_loss(pred_met.squeeze(dim=-1), batch.genmet).mean() - was_input_pred = torch.concat([torch.softmax(ypred["cls_binary"].transpose(1, 2), axis=-1), ypred["momentum"]], axis=-1) * batch.mask.unsqueeze( - axis=-1 - ) - was_input_true = torch.concat([torch.nn.functional.one_hot((y["cls_id"] != 0).to(torch.long)), y["momentum"]], axis=-1) * batch.mask.unsqueeze( - axis=-1 - ) + was_input_pred = torch.concat( + [torch.softmax(ypred["cls_binary"].transpose(1, 2), axis=-1), ypred["momentum"]], axis=-1 + ) * batch.mask.unsqueeze(axis=-1) + was_input_true = torch.concat( + [torch.nn.functional.one_hot((y["cls_id"] != 0).to(torch.long)), y["momentum"]], axis=-1 + ) * batch.mask.unsqueeze(axis=-1) # standardize Wasserstein loss std = was_input_true[batch.mask].std(axis=0) @@ -196,7 +194,9 @@ class FocalLoss(nn.Module): - y: (batch_size,) or (batch_size, d1, d2, ..., dK), K > 0. """ - def __init__(self, alpha: Optional[Tensor] = None, gamma: float = 0.0, reduction: str = "mean", ignore_index: int = -100): + def __init__( + self, alpha: Optional[Tensor] = None, gamma: float = 0.0, reduction: str = "mean", ignore_index: int = -100 + ): """Constructor. Args: alpha (Tensor, optional): Weights for each class. Defaults to None. @@ -380,28 +380,44 @@ def validation_plots(batch, ypred_raw, ytarget, ypred, tensorboard_writer, epoch plt.xlabel("particle proba") tensorboard_writer.add_figure("sig_proba_elemtype{}".format(int(xcls)), fig, global_step=epoch) - tensorboard_writer.add_histogram("pt_target", torch.clamp(batch.ytarget[batch.mask][:, 2], -10, 10), global_step=epoch) + tensorboard_writer.add_histogram( + "pt_target", torch.clamp(batch.ytarget[batch.mask][:, 2], -10, 10), global_step=epoch + ) tensorboard_writer.add_histogram("pt_pred", torch.clamp(ypred_raw[2][batch.mask][:, 0], -10, 10), global_step=epoch) ratio = (ypred_raw[2][batch.mask][:, 0] / batch.ytarget[batch.mask][:, 2])[batch.ytarget[batch.mask][:, 0] != 0] tensorboard_writer.add_histogram("pt_ratio", torch.clamp(ratio, -10, 10), global_step=epoch) - tensorboard_writer.add_histogram("eta_target", torch.clamp(batch.ytarget[batch.mask][:, 3], -10, 10), global_step=epoch) + tensorboard_writer.add_histogram( + "eta_target", torch.clamp(batch.ytarget[batch.mask][:, 3], -10, 10), global_step=epoch + ) tensorboard_writer.add_histogram("eta_pred", torch.clamp(ypred_raw[2][batch.mask][:, 1], -10, 10), global_step=epoch) ratio = (ypred_raw[2][batch.mask][:, 1] / batch.ytarget[batch.mask][:, 3])[batch.ytarget[batch.mask][:, 0] != 0] tensorboard_writer.add_histogram("eta_ratio", torch.clamp(ratio, -10, 10), global_step=epoch) - tensorboard_writer.add_histogram("sphi_target", torch.clamp(batch.ytarget[batch.mask][:, 4], -10, 10), global_step=epoch) - tensorboard_writer.add_histogram("sphi_pred", torch.clamp(ypred_raw[2][batch.mask][:, 2], -10, 10), global_step=epoch) + tensorboard_writer.add_histogram( + "sphi_target", torch.clamp(batch.ytarget[batch.mask][:, 4], -10, 10), global_step=epoch + ) + tensorboard_writer.add_histogram( + "sphi_pred", torch.clamp(ypred_raw[2][batch.mask][:, 2], -10, 10), global_step=epoch + ) ratio = (ypred_raw[2][batch.mask][:, 2] / batch.ytarget[batch.mask][:, 4])[batch.ytarget[batch.mask][:, 0] != 0] tensorboard_writer.add_histogram("sphi_ratio", torch.clamp(ratio, -10, 10), global_step=epoch) - tensorboard_writer.add_histogram("cphi_target", torch.clamp(batch.ytarget[batch.mask][:, 5], -10, 10), global_step=epoch) - tensorboard_writer.add_histogram("cphi_pred", torch.clamp(ypred_raw[2][batch.mask][:, 3], -10, 10), global_step=epoch) + tensorboard_writer.add_histogram( + "cphi_target", torch.clamp(batch.ytarget[batch.mask][:, 5], -10, 10), global_step=epoch + ) + tensorboard_writer.add_histogram( + "cphi_pred", torch.clamp(ypred_raw[2][batch.mask][:, 3], -10, 10), global_step=epoch + ) ratio = (ypred_raw[2][batch.mask][:, 3] / batch.ytarget[batch.mask][:, 5])[batch.ytarget[batch.mask][:, 0] != 0] tensorboard_writer.add_histogram("cphi_ratio", torch.clamp(ratio, -10, 10), global_step=epoch) - tensorboard_writer.add_histogram("energy_target", torch.clamp(batch.ytarget[batch.mask][:, 6], -10, 10), global_step=epoch) - tensorboard_writer.add_histogram("energy_pred", torch.clamp(ypred_raw[2][batch.mask][:, 4], -10, 10), global_step=epoch) + tensorboard_writer.add_histogram( + "energy_target", torch.clamp(batch.ytarget[batch.mask][:, 6], -10, 10), global_step=epoch + ) + tensorboard_writer.add_histogram( + "energy_pred", torch.clamp(ypred_raw[2][batch.mask][:, 4], -10, 10), global_step=epoch + ) ratio = (ypred_raw[2][batch.mask][:, 4] / batch.ytarget[batch.mask][:, 6])[batch.ytarget[batch.mask][:, 0] != 0] tensorboard_writer.add_histogram("energy_ratio", torch.clamp(ratio, -10, 10), global_step=epoch) @@ -462,7 +478,9 @@ def train_and_valid( if (world_size > 1) and (rank != 0): iterator = enumerate(data_loader) else: - iterator = tqdm.tqdm(enumerate(data_loader), total=len(data_loader), desc=f"Epoch {epoch} {train_or_valid} loop on rank={rank}") + iterator = tqdm.tqdm( + enumerate(data_loader), total=len(data_loader), desc=f"Epoch {epoch} {train_or_valid} loop on rank={rank}" + ) device_type = "cuda" if isinstance(rank, int) else "cpu" @@ -497,17 +515,23 @@ def train_and_valid( if not is_train: cm_X_target += sklearn.metrics.confusion_matrix( - batch.X[:, :, 0][batch.mask].detach().cpu().numpy(), ytarget["cls_id"][batch.mask].detach().cpu().numpy(), labels=range(13) + batch.X[:, :, 0][batch.mask].detach().cpu().numpy(), + ytarget["cls_id"][batch.mask].detach().cpu().numpy(), + labels=range(13), ) cm_X_pred += sklearn.metrics.confusion_matrix( - batch.X[:, :, 0][batch.mask].detach().cpu().numpy(), ypred["cls_id"][batch.mask].detach().cpu().numpy(), labels=range(13) + batch.X[:, :, 0][batch.mask].detach().cpu().numpy(), + ypred["cls_id"][batch.mask].detach().cpu().numpy(), + labels=range(13), ) cm_id += sklearn.metrics.confusion_matrix( - ytarget["cls_id"][batch.mask].detach().cpu().numpy(), ypred["cls_id"][batch.mask].detach().cpu().numpy(), labels=range(13) + ytarget["cls_id"][batch.mask].detach().cpu().numpy(), + ypred["cls_id"][batch.mask].detach().cpu().numpy(), + labels=range(13), ) - # save the events of the first validation batch for quick checks - if (rank == 0 or rank == "cpu") and itrain == 0: - validation_plots(batch, ypred_raw, ytarget, ypred, tensorboard_writer, epoch, outdir) + # # save the events of the first validation batch for quick checks + # if (rank == 0 or rank == "cpu") and itrain == 0: + # validation_plots(batch, ypred_raw, ytarget, ypred, tensorboard_writer, epoch, outdir) with torch.autocast(device_type=device_type, dtype=dtype, enabled=device_type == "cuda"): if is_train: loss = mlpf_loss(ytarget, ypred, batch) @@ -609,13 +633,28 @@ def train_and_valid( if not is_train and comet_experiment: comet_experiment.log_confusion_matrix( - matrix=cm_X_target, title="Element to target", row_label="X", column_label="target", epoch=epoch, file_name="cm_X_target.json" + matrix=cm_X_target, + title="Element to target", + row_label="X", + column_label="target", + epoch=epoch, + file_name="cm_X_target.json", ) comet_experiment.log_confusion_matrix( - matrix=cm_X_pred, title="Element to pred", row_label="X", column_label="pred", epoch=epoch, file_name="cm_X_pred.json" + matrix=cm_X_pred, + title="Element to pred", + row_label="X", + column_label="pred", + epoch=epoch, + file_name="cm_X_pred.json", ) comet_experiment.log_confusion_matrix( - matrix=cm_id, title="Target to pred", row_label="target", column_label="pred", epoch=epoch, file_name="cm_id.json" + matrix=cm_id, + title="Target to pred", + row_label="target", + column_label="pred", + epoch=epoch, + file_name="cm_id.json", ) num_data = torch.tensor(len(data_loader), device=rank) @@ -700,9 +739,31 @@ def train_mlpf( for epoch in range(start_epoch, num_epochs + 1): t0 = time.time() + losses_v = train_and_valid( + rank, + world_size, + outdir, + model, + optimizer, + train_loader=train_loader, + valid_loader=valid_loader, + trainable=trainable, + is_train=False, + lr_schedule=None, + comet_experiment=comet_experiment, + comet_step_freq=comet_step_freq, + epoch=epoch, + dtype=dtype, + tensorboard_writer=tensorboard_writer_valid, + save_attention=save_attention, + ) + t_valid = time.time() + # training step, edit here to profile a specific epoch if epoch == -1: - with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True, with_stack=True) as prof: + with profile( + activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True, with_stack=True + ) as prof: with record_function("model_train"): losses_t = train_and_valid( rank, @@ -740,25 +801,25 @@ def train_mlpf( ) t_train = time.time() # epoch time excluding validation - losses_v = train_and_valid( - rank, - world_size, - outdir, - model, - optimizer, - train_loader=train_loader, - valid_loader=valid_loader, - trainable=trainable, - is_train=False, - lr_schedule=None, - comet_experiment=comet_experiment, - comet_step_freq=comet_step_freq, - epoch=epoch, - dtype=dtype, - tensorboard_writer=tensorboard_writer_valid, - save_attention=save_attention, - ) - t_valid = time.time() + # losses_v = train_and_valid( + # rank, + # world_size, + # outdir, + # model, + # optimizer, + # train_loader=train_loader, + # valid_loader=valid_loader, + # trainable=trainable, + # is_train=False, + # lr_schedule=None, + # comet_experiment=comet_experiment, + # comet_step_freq=comet_step_freq, + # epoch=epoch, + # dtype=dtype, + # tensorboard_writer=tensorboard_writer_valid, + # save_attention=save_attention, + # ) + # t_valid = time.time() if comet_experiment: comet_experiment.log_metrics(losses_t, prefix="epoch_train_loss", epoch=epoch) @@ -899,7 +960,10 @@ def run(rank, world_size, config, args, outdir, logfile): start_epoch = 1 if config["load"]: # load a pre-trained model - with open(f"{outdir}/model_kwargs.pkl", "rb") as f: + pload = Path(config["load"]) + + # with open(f"{outdir}/model_kwargs.pkl", "rb") as f: + with open(f"{pload.parent}/model_kwargs.pkl", "rb") as f: model_kwargs = pkl.load(f) _logger.info("model_kwargs: {}".format(model_kwargs)) @@ -962,7 +1026,9 @@ def run(rank, world_size, config, args, outdir, logfile): _logger.info(f"Model directory {outdir}", color="bold") if args.comet: - comet_experiment = create_comet_experiment(config["comet_name"], comet_offline=config["comet_offline"], outdir=outdir) + comet_experiment = create_comet_experiment( + config["comet_name"], comet_offline=config["comet_offline"], outdir=outdir + ) comet_experiment.set_name(f"rank_{rank}_{Path(outdir).name}") comet_experiment.log_parameter("run_id", Path(outdir).name) comet_experiment.log_parameter("world_size", world_size) @@ -1055,7 +1121,9 @@ def run(rank, world_size, config, args, outdir, logfile): test_loader = torch.utils.data.DataLoader( ds, batch_size=batch_size, - collate_fn=Collater(["X", "ytarget", "ytarget_pt_orig", "ytarget_e_orig", "ycand", "genjets", "targetjets"], ["genmet"]), + collate_fn=Collater( + ["X", "ytarget", "ytarget_pt_orig", "ytarget_e_orig", "ycand", "genjets", "targetjets"], ["genmet"] + ), sampler=sampler, num_workers=config["num_workers"], prefetch_factor=config["prefetch_factor"], @@ -1076,7 +1144,7 @@ def run(rank, world_size, config, args, outdir, logfile): jetdef = fastjet.JetDefinition(fastjet.ee_genkt_algorithm, 0.4, -1.0) jet_ptcut = 5 - if config["dataset"] == "cms": + elif config["dataset"] == "cms": import fastjet jetdef = fastjet.JetDefinition(fastjet.antikt_algorithm, 0.4) @@ -1224,7 +1292,9 @@ def train_ray_trial(config, args, outdir=None): loaders = get_interleaved_dataloaders(world_size, rank, config, use_cuda, use_ray=True) if args.comet: - comet_experiment = create_comet_experiment(config["comet_name"], comet_offline=config["comet_offline"], outdir=outdir) + comet_experiment = create_comet_experiment( + config["comet_name"], comet_offline=config["comet_offline"], outdir=outdir + ) comet_experiment.set_name(f"world_rank_{world_rank}_{Path(outdir).name}") comet_experiment.log_parameter("run_id", Path(outdir).name) comet_experiment.log_parameter("world_size", world_size) @@ -1260,7 +1330,9 @@ def train_ray_trial(config, args, outdir=None): if args.resume_training: model, optimizer = load_checkpoint(checkpoint, model, optimizer) start_epoch = checkpoint["extra_state"]["epoch"] + 1 - lr_schedule = get_lr_schedule(config, optimizer, config["num_epochs"], steps_per_epoch, last_epoch=start_epoch - 1) + lr_schedule = get_lr_schedule( + config, optimizer, config["num_epochs"], steps_per_epoch, last_epoch=start_epoch - 1 + ) else: # start a new training with model weights loaded from a pre-trained model model = load_checkpoint(checkpoint, model) @@ -1375,7 +1447,6 @@ def run_hpo(config, args): import ray from ray import tune from ray.train.torch import TorchTrainer - from raytune.pt_search_space import raytune_num_samples, search_space from raytune.utils import get_raytune_schedule, get_raytune_search_alg @@ -1424,7 +1495,9 @@ def run_hpo(config, args): if tune.Tuner.can_restore(str(expdir)): # resume unfinished HPO run - tuner = tune.Tuner.restore(str(expdir), trainable=trainer, resume_errored=True, restart_errored=False, resume_unfinished=True) + tuner = tune.Tuner.restore( + str(expdir), trainable=trainer, resume_errored=True, restart_errored=False, resume_unfinished=True + ) else: # start new HPO run search_space = {"train_loop_config": search_space} # the ray TorchTrainer only takes a single arg: train_loop_config @@ -1465,4 +1538,6 @@ def run_hpo(config, args): print(result_df.columns) logging.info("Total time of Tuner.fit(): {}".format(end - start)) - logging.info("Best hyperparameters found according to {} were: {}".format(config["raytune"]["default_metric"], best_config)) + logging.info( + "Best hyperparameters found according to {} were: {}".format(config["raytune"]["default_metric"], best_config) + ) diff --git a/mlpf/pipeline.py b/mlpf/pipeline.py index ae405bc3f..589889554 100644 --- a/mlpf/pipeline.py +++ b/mlpf/pipeline.py @@ -8,6 +8,7 @@ import logging import os from pathlib import Path + import matplotlib import numpy as np @@ -18,9 +19,15 @@ os.environ["OMP_NUM_THREADS"] = "1" import yaml -from mlpf.model.training import device_agnostic_run, override_config, run_hpo, run_ray_training from utils import create_experiment_dir +from mlpf.model.training import ( + device_agnostic_run, + override_config, + run_hpo, + run_ray_training, +) + parser = argparse.ArgumentParser() # add default=None to all arparse arguments to ensure they do not override @@ -29,10 +36,14 @@ parser.add_argument("--prefix", type=str, default=None, help="prefix appended to result dir name") parser.add_argument("--data-dir", type=str, default=None, help="path to `tensorflow_datasets/`") parser.add_argument("--gpus", type=int, default=None, help="to use CPU set to 0; else e.g., 4") -parser.add_argument("--gpu-batch-multiplier", type=int, default=None, help="Increase batch size per GPU by this constant factor") +parser.add_argument( + "--gpu-batch-multiplier", type=int, default=None, help="Increase batch size per GPU by this constant factor" +) parser.add_argument("--num-workers", type=int, default=None, help="number of processes to load the data") parser.add_argument("--prefetch-factor", type=int, default=None, help="number of samples to fetch & prefetch at every call") -parser.add_argument("--resume-training", type=str, default=None, help="training dir containing the checkpointed training to resume") +parser.add_argument( + "--resume-training", type=str, default=None, help="training dir containing the checkpointed training to resume" +) parser.add_argument("--load", type=str, default=None, help="load checkpoint and start new training from epoch 1") parser.add_argument("--train", action="store_true", default=None, help="initiates a training") @@ -47,7 +58,9 @@ help="which graph layer to use", choices=["attention", "gnn_lsh", "mamba"], ) -parser.add_argument("--num-convs", type=int, default=None, help="number of cross-particle convolution (GNN, attention, Mamba) layers") +parser.add_argument( + "--num-convs", type=int, default=None, help="number of cross-particle convolution (GNN, attention, Mamba) layers" +) parser.add_argument("--make-plots", action="store_true", default=None, help="make plots of the test predictions") parser.add_argument("--ntrain", type=int, default=None, help="training samples to use, if None use entire dataset") parser.add_argument("--ntest", type=int, default=None, help="training samples to use, if None use entire dataset") @@ -152,7 +165,9 @@ def main(): run_hpo(config, args) else: outdir = get_outdir(args.resume_training, config["load"]) - if outdir is None: + + # if outdir is None: + if (args.load and args.train) or (outdir is None): outdir = create_experiment_dir( prefix=(args.prefix or "") + Path(args.config).stem + "_", experiments_dir=args.experiments_dir if args.experiments_dir else "experiments", diff --git a/parameters/pytorch/pyg-cld.yaml b/parameters/pytorch/pyg-cld.yaml index 204689385..76ebe263c 100644 --- a/parameters/pytorch/pyg-cld.yaml +++ b/parameters/pytorch/pyg-cld.yaml @@ -1,16 +1,17 @@ backend: pytorch -dataset: cld +save_attention: yes +dataset: clic sort_data: no data_dir: gpus: 1 gpu_batch_multiplier: 1 load: -num_epochs: 100 +num_epochs: 30 patience: 20 lr: 0.0001 lr_schedule: cosinedecay # constant, cosinedecay, onecycle -conv_type: gnn_lsh +conv_type: attention # gnn_lsh, attention, mamba, flashattention ntrain: ntest: nvalid: @@ -26,16 +27,16 @@ val_freq: # run an extra validation run every val_freq training steps model: trainable: all learned_representation_mode: last #last, concat - input_encoding: joint #split, joint - pt_mode: linear + input_encoding: split #split, joint + pt_mode: direct-elemtype-split eta_mode: linear sin_phi_mode: linear cos_phi_mode: linear - energy_mode: linear + energy_mode: direct-elemtype-split gnn_lsh: conv_type: gnn_lsh - embedding_dim: 256 + embedding_dim: 512 width: 512 num_convs: 8 activation: "elu" @@ -51,15 +52,16 @@ model: attention: conv_type: attention num_convs: 6 - dropout_ff: 0.0 + dropout_ff: 0.1 dropout_conv_id_mha: 0.0 dropout_conv_id_ff: 0.0 - dropout_conv_reg_mha: 0.0 - dropout_conv_reg_ff: 0.0 + dropout_conv_reg_mha: 0.1 + dropout_conv_reg_ff: 0.1 activation: "relu" - head_dim: 16 + head_dim: 32 num_heads: 32 - attention_type: flash + attention_type: math + use_pre_layernorm: True mamba: conv_type: mamba @@ -80,8 +82,8 @@ lr_schedule_config: pct_start: 0.3 raytune: - local_dir: # Note: please specify an absolute path - sched: asha # asha, hyperband + local_dir: # Note: please specify an absolute path + sched: # asha, hyperband search_alg: # bayes, bohb, hyperopt, nevergrad, scikit default_metric: "val_loss" default_mode: "min" @@ -100,21 +102,24 @@ raytune: n_random_steps: 10 train_dataset: - cld: + clic: physical: batch_size: 1 samples: cld_edm_ttbar_pf: - version: 2.0.0 + version: 2.5.0 + splits: [1,2,3,4,5,6,7,8,9,10] valid_dataset: - cld: + clic: physical: batch_size: 1 samples: cld_edm_ttbar_pf: - version: 2.0.0 + version: 2.5.0 + splits: [1,2,3,4,5,6,7,8,9,10] test_dataset: cld_edm_ttbar_pf: - version: 2.0.0 + version: 2.5.0 + splits: [1,2,3,4,5,6,7,8,9,10] \ No newline at end of file