diff --git a/mlpf/model/training.py b/mlpf/model/training.py index b6412fe19..d632be9dd 100644 --- a/mlpf/model/training.py +++ b/mlpf/model/training.py @@ -614,8 +614,29 @@ def run(rank, world_size, config, outdir, logfile): checkpoint_dir.mkdir(parents=True, exist_ok=True) if config["load"]: # load a pre-trained model - with open(f"{outdir}/model_kwargs.pkl", "rb") as f: - model_kwargs = pkl.load(f) + + if config["finetune"]: + # outdir is now the new directory for finetuning so must retrieve model_kwargs from the load dir + def get_relevant_directory(path): + # Get the parent directory of the given path + parent_dir = os.path.dirname(path) + + # Get the parent of the parent directory + grandparent_dir = os.path.dirname(parent_dir) + + # Check if the parent directory is "checkpoints" + if os.path.basename(parent_dir) == "checkpoints": + return grandparent_dir + else: + return parent_dir + + with open(f"{get_relevant_directory(config['load'])}/model_kwargs.pkl", "rb") as f: + model_kwargs = pkl.load(f) + + else: + with open(f"{outdir}/model_kwargs.pkl", "rb") as f: + model_kwargs = pkl.load(f) + _logger.info("model_kwargs: {}".format(model_kwargs)) if config["conv_type"] == "attention": @@ -626,9 +647,10 @@ def run(rank, world_size, config, outdir, logfile): checkpoint = torch.load(config["load"], map_location=torch.device(rank)) - # check if we reached the first epoch in the checkpoint - if "epoch" in checkpoint["extra_state"]: - start_epoch = checkpoint["extra_state"]["epoch"] + 1 + if not config["finetune"]: # for --finetune we want to start the count from scratch + # check if we reached the first epoch in the checkpoint + if "epoch" in checkpoint["extra_state"]: + start_epoch = checkpoint["extra_state"]["epoch"] + 1 missing_keys, strict = [], True for k in model.state_dict().keys(): diff --git a/mlpf/pipeline.py b/mlpf/pipeline.py index 0f3d2117d..5ab978f1f 100644 --- a/mlpf/pipeline.py +++ b/mlpf/pipeline.py @@ -86,6 +86,13 @@ ) parser.add_argument("--test-datasets", nargs="+", default=[], help="test samples to process") +parser.add_argument( + "--finetune", + action="store_true", + default=None, + help="will load and run a training and log the result in the --prefix directory", +) + def get_outdir(resume_training, load): outdir = None @@ -96,10 +103,8 @@ def get_outdir(resume_training, load): if pload.name == "checkpoint.pth": # the checkpoint is likely from a Ray Train run and we need to step one dir higher up outdir = str(pload.parent.parent.parent) - elif pload.name == "best_weights.pth": - outdir = str(pload.parent) else: - # the checkpoint is likely from a DDP run and we need to step up one dir less + # the checkpoint is likely not from a Ray Train run and we need to step up one dir less outdir = str(pload.parent.parent) if not (outdir is None): assert os.path.isfile("{}/model_kwargs.pkl".format(outdir)) @@ -158,7 +163,7 @@ def main(): run_hpo(config, args) else: outdir = get_outdir(args.resume_training, config["load"]) - if outdir is None: + if (outdir is None) or (args.finetune): outdir = create_experiment_dir( prefix=(args.prefix or "") + Path(args.config).stem + "_", experiments_dir=args.experiments_dir if args.experiments_dir else "experiments", diff --git a/parameters/pytorch/pyg-cld.yaml b/parameters/pytorch/pyg-cld.yaml index 204689385..c1c4196c8 100644 --- a/parameters/pytorch/pyg-cld.yaml +++ b/parameters/pytorch/pyg-cld.yaml @@ -1,16 +1,20 @@ -backend: pytorch - -dataset: cld +train: yes +test: yes +make_plots: yes +comet: yes +save_attention: yes +dataset: clic sort_data: no data_dir: gpus: 1 gpu_batch_multiplier: 1 load: -num_epochs: 100 +finetune: +num_epochs: 10 patience: 20 lr: 0.0001 lr_schedule: cosinedecay # constant, cosinedecay, onecycle -conv_type: gnn_lsh +conv_type: attention # gnn_lsh, attention, mamba, flashattention ntrain: ntest: nvalid: @@ -26,16 +30,16 @@ val_freq: # run an extra validation run every val_freq training steps model: trainable: all learned_representation_mode: last #last, concat - input_encoding: joint #split, joint - pt_mode: linear + input_encoding: split #split, joint + pt_mode: direct-elemtype-split eta_mode: linear sin_phi_mode: linear cos_phi_mode: linear - energy_mode: linear + energy_mode: direct-elemtype-split gnn_lsh: conv_type: gnn_lsh - embedding_dim: 256 + embedding_dim: 512 width: 512 num_convs: 8 activation: "elu" @@ -50,16 +54,17 @@ model: attention: conv_type: attention - num_convs: 6 + num_convs: 3 dropout_ff: 0.0 dropout_conv_id_mha: 0.0 dropout_conv_id_ff: 0.0 dropout_conv_reg_mha: 0.0 dropout_conv_reg_ff: 0.0 activation: "relu" - head_dim: 16 + head_dim: 32 num_heads: 32 - attention_type: flash + attention_type: math + use_pre_layernorm: True mamba: conv_type: mamba @@ -80,8 +85,8 @@ lr_schedule_config: pct_start: 0.3 raytune: - local_dir: # Note: please specify an absolute path - sched: asha # asha, hyperband + local_dir: # Note: please specify an absolute path + sched: # asha, hyperband search_alg: # bayes, bohb, hyperopt, nevergrad, scikit default_metric: "val_loss" default_mode: "min" @@ -100,21 +105,24 @@ raytune: n_random_steps: 10 train_dataset: - cld: + clic: physical: batch_size: 1 samples: cld_edm_ttbar_pf: - version: 2.0.0 + version: 2.5.0 + splits: [1,2,3,4,5,6,7,8,9,10] valid_dataset: - cld: + clic: physical: batch_size: 1 samples: cld_edm_ttbar_pf: - version: 2.0.0 + version: 2.5.0 + splits: [1,2,3,4,5,6,7,8,9,10] test_dataset: cld_edm_ttbar_pf: - version: 2.0.0 + version: 2.5.0 + splits: [1,2,3,4,5,6,7,8,9,10] diff --git a/parameters/pytorch/pyg-clic-hits.yaml b/parameters/pytorch/pyg-clic-hits.yaml index 62b470931..92e79deff 100644 --- a/parameters/pytorch/pyg-clic-hits.yaml +++ b/parameters/pytorch/pyg-clic-hits.yaml @@ -5,6 +5,7 @@ data_dir: gpus: 1 gpu_batch_multiplier: 1 load: +finetune: num_epochs: 20 patience: 20 lr: 0.001 diff --git a/parameters/pytorch/pyg-clic.yaml b/parameters/pytorch/pyg-clic.yaml index 083dd7610..89dc39513 100644 --- a/parameters/pytorch/pyg-clic.yaml +++ b/parameters/pytorch/pyg-clic.yaml @@ -9,6 +9,7 @@ data_dir: gpus: 1 gpu_batch_multiplier: 1 load: +finetune: num_epochs: 10 patience: 20 lr: 0.0001 diff --git a/parameters/pytorch/pyg-cms-nopu.yaml b/parameters/pytorch/pyg-cms-nopu.yaml index 093852907..e3715d349 100644 --- a/parameters/pytorch/pyg-cms-nopu.yaml +++ b/parameters/pytorch/pyg-cms-nopu.yaml @@ -7,6 +7,7 @@ data_dir: gpus: 1 gpu_batch_multiplier: 1 load: +finetune: num_epochs: 100 patience: 20 lr: 0.0001 diff --git a/parameters/pytorch/pyg-cms-ttbar-nopu.yaml b/parameters/pytorch/pyg-cms-ttbar-nopu.yaml index 029281d67..a7eca5fbc 100644 --- a/parameters/pytorch/pyg-cms-ttbar-nopu.yaml +++ b/parameters/pytorch/pyg-cms-ttbar-nopu.yaml @@ -9,6 +9,7 @@ data_dir: gpus: 1 gpu_batch_multiplier: 1 load: +finetune: num_epochs: 100 patience: 20 lr: 0.0001 diff --git a/parameters/pytorch/pyg-cms.yaml b/parameters/pytorch/pyg-cms.yaml index ecc34ae63..4b1b2d18f 100644 --- a/parameters/pytorch/pyg-cms.yaml +++ b/parameters/pytorch/pyg-cms.yaml @@ -9,6 +9,7 @@ data_dir: gpus: 1 gpu_batch_multiplier: 1 load: +finetune: num_epochs: 5 patience: 20 lr: 0.0001