Finetuning (#387)
* add --finetune arg

* edit load parent dir for eval

* formatting and linting

* simplify if else loop

* disabling vscode isort and black
farakiko authored Jan 15, 2025
1 parent 0c61f2f commit 210531c
Showing 8 changed files with 68 additions and 28 deletions.
32 changes: 27 additions & 5 deletions mlpf/model/training.py
@@ -614,8 +614,29 @@ def run(rank, world_size, config, outdir, logfile):
         checkpoint_dir.mkdir(parents=True, exist_ok=True)
 
     if config["load"]: # load a pre-trained model
-        with open(f"{outdir}/model_kwargs.pkl", "rb") as f:
-            model_kwargs = pkl.load(f)
+
+        if config["finetune"]:
+            # outdir is now the new directory for finetuning so must retrieve model_kwargs from the load dir
+            def get_relevant_directory(path):
+                # Get the parent directory of the given path
+                parent_dir = os.path.dirname(path)
+
+                # Get the parent of the parent directory
+                grandparent_dir = os.path.dirname(parent_dir)
+
+                # Check if the parent directory is "checkpoints"
+                if os.path.basename(parent_dir) == "checkpoints":
+                    return grandparent_dir
+                else:
+                    return parent_dir
+
+            with open(f"{get_relevant_directory(config['load'])}/model_kwargs.pkl", "rb") as f:
+                model_kwargs = pkl.load(f)
+
+        else:
+            with open(f"{outdir}/model_kwargs.pkl", "rb") as f:
+                model_kwargs = pkl.load(f)
+
         _logger.info("model_kwargs: {}".format(model_kwargs))
 
     if config["conv_type"] == "attention":
@@ -626,9 +647,10 @@ def run(rank, world_size, config, outdir, logfile):
 
         checkpoint = torch.load(config["load"], map_location=torch.device(rank))
 
-        # check if we reached the first epoch in the checkpoint
-        if "epoch" in checkpoint["extra_state"]:
-            start_epoch = checkpoint["extra_state"]["epoch"] + 1
+        if not config["finetune"]: # for --finetune we want to start the count from scratch
+            # check if we reached the first epoch in the checkpoint
+            if "epoch" in checkpoint["extra_state"]:
+                start_epoch = checkpoint["extra_state"]["epoch"] + 1
 
         missing_keys, strict = [], True
         for k in model.state_dict().keys():
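For context, the checkpoint-path resolution added above can be exercised on its own. A minimal standalone sketch of the same logic; the example paths below are illustrative, not taken from the repository:

import os

def get_relevant_directory(path):
    # Parent directory of the checkpoint file
    parent_dir = os.path.dirname(path)
    # Parent of the parent directory
    grandparent_dir = os.path.dirname(parent_dir)
    # A checkpoint stored under a "checkpoints" subdirectory sits one level
    # below the run directory that holds model_kwargs.pkl
    if os.path.basename(parent_dir) == "checkpoints":
        return grandparent_dir
    return parent_dir

print(get_relevant_directory("experiments/run1/checkpoints/checkpoint-10.pth"))
# -> experiments/run1
print(get_relevant_directory("experiments/run1/best_weights.pth"))
# -> experiments/run1

Either way, the returned directory is where model_kwargs.pkl is expected to live.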
13 changes: 9 additions & 4 deletions mlpf/pipeline.py
@@ -86,6 +86,13 @@
 )
 parser.add_argument("--test-datasets", nargs="+", default=[], help="test samples to process")
 
+parser.add_argument(
+    "--finetune",
+    action="store_true",
+    default=None,
+    help="will load and run a training and log the result in the --prefix directory",
+)
+
 
 def get_outdir(resume_training, load):
     outdir = None
@@ -96,10 +103,8 @@ def get_outdir(resume_training, load):
         if pload.name == "checkpoint.pth":
             # the checkpoint is likely from a Ray Train run and we need to step one dir higher up
             outdir = str(pload.parent.parent.parent)
-        elif pload.name == "best_weights.pth":
-            outdir = str(pload.parent)
         else:
-            # the checkpoint is likely from a DDP run and we need to step up one dir less
+            # the checkpoint is likely not from a Ray Train run and we need to step up one dir less
             outdir = str(pload.parent.parent)
     if not (outdir is None):
         assert os.path.isfile("{}/model_kwargs.pkl".format(outdir))
@@ -158,7 +163,7 @@ def main():
         run_hpo(config, args)
     else:
         outdir = get_outdir(args.resume_training, config["load"])
-        if outdir is None:
+        if (outdir is None) or (args.finetune):
             outdir = create_experiment_dir(
                 prefix=(args.prefix or "") + Path(args.config).stem + "_",
                 experiments_dir=args.experiments_dir if args.experiments_dir else "experiments",
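The net effect on the pipeline: get_outdir() still resolves the run directory from the --load path, but passing --finetune now forces a fresh experiment directory (named from --prefix and the config stem) even when that resolution succeeds, so finetuning results do not overwrite the original run. A condensed sketch of the resolution rule, with illustrative paths:

from pathlib import Path

def resolve_outdir(load_path: str) -> str:
    # Mirrors get_outdir() above: a file named checkpoint.pth is assumed to
    # come from a Ray Train run and is nested one directory deeper
    pload = Path(load_path)
    if pload.name == "checkpoint.pth":
        return str(pload.parent.parent.parent)
    # any other checkpoint: step up one directory less
    return str(pload.parent.parent)

print(resolve_outdir("experiments/run1/step-5/sub/checkpoint.pth"))     # experiments/run1
print(resolve_outdir("experiments/run1/checkpoints/best_weights.pth"))  # experiments/run1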
46 changes: 27 additions & 19 deletions parameters/pytorch/pyg-cld.yaml
@@ -1,16 +1,20 @@
 backend: pytorch
 
-dataset: cld
+train: yes
+test: yes
+make_plots: yes
+comet: yes
+save_attention: yes
+dataset: clic
 sort_data: no
 data_dir:
 gpus: 1
 gpu_batch_multiplier: 1
 load:
-num_epochs: 100
+finetune:
+num_epochs: 10
 patience: 20
 lr: 0.0001
 lr_schedule: cosinedecay # constant, cosinedecay, onecycle
-conv_type: gnn_lsh
+conv_type: attention # gnn_lsh, attention, mamba, flashattention
 ntrain:
 ntest:
 nvalid:
@@ -26,16 +30,16 @@ val_freq: # run an extra validation run every val_freq training steps
 model:
   trainable: all
   learned_representation_mode: last #last, concat
-  input_encoding: joint #split, joint
-  pt_mode: linear
+  input_encoding: split #split, joint
+  pt_mode: direct-elemtype-split
   eta_mode: linear
   sin_phi_mode: linear
   cos_phi_mode: linear
-  energy_mode: linear
+  energy_mode: direct-elemtype-split
 
   gnn_lsh:
     conv_type: gnn_lsh
-    embedding_dim: 256
+    embedding_dim: 512
     width: 512
     num_convs: 8
     activation: "elu"
@@ -50,16 +54,17 @@ model:
 
   attention:
     conv_type: attention
-    num_convs: 6
+    num_convs: 3
     dropout_ff: 0.0
     dropout_conv_id_mha: 0.0
     dropout_conv_id_ff: 0.0
    dropout_conv_reg_mha: 0.0
     dropout_conv_reg_ff: 0.0
     activation: "relu"
-    head_dim: 16
+    head_dim: 32
     num_heads: 32
-    attention_type: flash
+    attention_type: math
+    use_pre_layernorm: True
 
   mamba:
     conv_type: mamba
@@ -80,8 +85,8 @@ lr_schedule_config:
     pct_start: 0.3
 
 raytune:
-  local_dir: # Note: please specify an absolute path
-  sched: asha # asha, hyperband
+  local_dir: # Note: please specify an absolute path
+  sched: # asha, hyperband
   search_alg: # bayes, bohb, hyperopt, nevergrad, scikit
   default_metric: "val_loss"
   default_mode: "min"
@@ -100,21 +105,24 @@ raytune:
     n_random_steps: 10
 
 train_dataset:
-  cld:
+  clic:
     physical:
       batch_size: 1
       samples:
         cld_edm_ttbar_pf:
-          version: 2.0.0
+          version: 2.5.0
+          splits: [1,2,3,4,5,6,7,8,9,10]
 
 valid_dataset:
-  cld:
+  clic:
     physical:
       batch_size: 1
       samples:
         cld_edm_ttbar_pf:
-          version: 2.0.0
+          version: 2.5.0
+          splits: [1,2,3,4,5,6,7,8,9,10]
 
 test_dataset:
   cld_edm_ttbar_pf:
-    version: 2.0.0
+    version: 2.5.0
+    splits: [1,2,3,4,5,6,7,8,9,10]
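Note that every config leaves the new finetune: key empty. Assuming these YAML files are parsed with a standard PyYAML load, an empty key comes back as None, so the config["finetune"] checks in training.py stay falsy unless finetuning is explicitly enabled; a minimal sketch:

import yaml  # PyYAML

snippet = """
load:
finetune:
num_epochs: 10
"""
config = yaml.safe_load(snippet)
print(config["finetune"])        # None: empty YAML values parse as null
print(bool(config["finetune"]))  # False, so the finetune branches are skipped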
1 change: 1 addition & 0 deletions parameters/pytorch/pyg-clic-hits.yaml
@@ -5,6 +5,7 @@ data_dir:
 gpus: 1
 gpu_batch_multiplier: 1
 load:
+finetune:
 num_epochs: 20
 patience: 20
 lr: 0.001
1 change: 1 addition & 0 deletions parameters/pytorch/pyg-clic.yaml
@@ -9,6 +9,7 @@ data_dir:
 gpus: 1
 gpu_batch_multiplier: 1
 load:
+finetune:
 num_epochs: 10
 patience: 20
 lr: 0.0001
1 change: 1 addition & 0 deletions parameters/pytorch/pyg-cms-nopu.yaml
@@ -7,6 +7,7 @@ data_dir:
 gpus: 1
 gpu_batch_multiplier: 1
 load:
+finetune:
 num_epochs: 100
 patience: 20
 lr: 0.0001
1 change: 1 addition & 0 deletions parameters/pytorch/pyg-cms-ttbar-nopu.yaml
@@ -9,6 +9,7 @@ data_dir:
 gpus: 1
 gpu_batch_multiplier: 1
 load:
+finetune:
 num_epochs: 100
 patience: 20
 lr: 0.0001
1 change: 1 addition & 0 deletions parameters/pytorch/pyg-cms.yaml
@@ -9,6 +9,7 @@ data_dir:
 gpus: 1
 gpu_batch_multiplier: 1
 load:
+finetune:
 num_epochs: 5
 patience: 20
 lr: 0.0001
