#!/usr/bin/env bash
# train.sh — training recipes for MT3 / MR-MT3 experiments.
# Each section below is a standalone experiment; as written they run
# sequentially. Comment out the sections you do not want to run.
set -euo pipefail
# ======= train baseline ======= #
# This experiment trains MT3 from scratch.
# Note: 'devices=[0,1]' is quoted — unquoted, [0,1] is a shell glob that
# would expand if a file named "devices=0" or "devices=1" exists in cwd.
# The last argument line carries NO trailing "\": a dangling continuation
# would splice the following file line into this command.
HYDRA_FULL_ERROR=1 OMP_NUM_THREADS=1 python3 train.py \
    --config-path="config" \
    --config-name="config_slakh_f1_0.65" \
    'devices=[0,1]' \
    hydra/job_logging=disabled \
    model="MT3Net" \
    dataset="Slakh" \
    split_frame_length=2000 \
    eval.eval_after_num_epoch=400 \
    eval.eval_first_n_examples=3 \
    eval.eval_per_epoch=10 \
    eval.contiguous_inference=False
# ======= train segmem with prev_frame and context = N ======= #
# This experiment trains MR-MT3 which takes the immediate previous segment
# as memory. The memory block is truncated at length `model_segmem_length`
# which corresponds to `L_agg` in the paper.
# Note: 'devices=[0,1]' is quoted to stop shell glob expansion, and the
# final argument line has no trailing "\" (a dangling continuation would
# swallow the next file line into this command).
HYDRA_FULL_ERROR=1 OMP_NUM_THREADS=1 python3 train.py \
    --config-path="config" \
    --config-name="config_slakh_segmem" \
    'devices=[0,1]' \
    hydra/job_logging=disabled \
    model="MT3NetSegMemV2WithPrev" \
    dataset="SlakhPrev" \
    dataset_use_tf_spectral_ops=False \
    dataset_is_randomize_tokens=True \
    split_frame_length=2000 \
    model_segmem_length=64 \
    trainer.check_val_every_n_epoch=20 \
    eval.eval_after_num_epoch=400 \
    eval.eval_first_n_examples=3 \
    eval.eval_per_epoch=10 \
    eval.contiguous_inference=True
# ======= train segmem with prev_frame, prev_augment, context = N ======= #
# This experiment trains MR-MT3 which takes the prior segment as memory.
# This prior segment can be up to N "hops" before the current segment, where
# N = `dataset_prev_augment_frames`, and is written as `L_max_hop` in the paper.
# Similarly, the memory block is truncated at length `model_segmem_length`
# which corresponds to `L_agg` in the paper.
# Note: 'devices=[0,1]' is quoted to stop shell glob expansion, and the
# final argument line has no trailing "\" (a dangling continuation would
# swallow the next file line into this command).
HYDRA_FULL_ERROR=1 OMP_NUM_THREADS=1 python3 train.py \
    --config-path="config" \
    --config-name="config_slakh_segmem" \
    'devices=[0,1]' \
    hydra/job_logging=disabled \
    model="MT3NetSegMemV2WithPrev" \
    dataset="SlakhPrevAugment" \
    dataset_use_tf_spectral_ops=False \
    dataset_is_randomize_tokens=True \
    split_frame_length=2000 \
    model_segmem_length=64 \
    dataset_prev_augment_frames=3 \
    trainer.check_val_every_n_epoch=20 \
    eval.eval_after_num_epoch=400 \
    eval.eval_first_n_examples=3 \
    eval.eval_per_epoch=10 \
    eval.contiguous_inference=True
# ======= continual training ======= #
# This experiment pre-loads the MT3 official checkpoint and continues
# training for `num_epochs` epochs with the experiment settings proposed
# above. Note that to match the MT3 official checkpoint, we need to use
# TF spectral_ops (dataset_use_tf_spectral_ops=True).
# Note: 'devices=[0,1]' is quoted to stop shell glob expansion, and the
# final argument line has no trailing "\" (the original left a dangling
# continuation at end-of-file).
HYDRA_FULL_ERROR=1 OMP_NUM_THREADS=1 python3 train.py \
    --config-path="config" \
    --config-name="config_slakh_segmem_finetune" \
    'devices=[0,1]' \
    hydra/job_logging=disabled \
    model="MT3NetSegMemV2WithPrevFineTune" \
    dataset="SlakhPrevAugment" \
    dataset_use_tf_spectral_ops=True \
    dataset_is_randomize_tokens=True \
    split_frame_length=2000 \
    model_segmem_length=64 \
    dataset_prev_augment_frames=3 \
    trainer.check_val_every_n_epoch=20 \
    optim.lr=1e-5 \
    num_epochs=100 \
    path="../../../pretrained/mt3.pth" \
    eval.eval_after_num_epoch=400 \
    eval.eval_first_n_examples=3 \
    eval.eval_per_epoch=10 \
    eval.contiguous_inference=True