pretrain.sh
#!/bin/bash
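# LLaVA-style stage-1 pretraining: trains only the multimodal MLP projector
# (--tune_mm_mlp_adapter True) on the LCS-558K image-caption set, with a
# frozen Vicuna-7B LLM and a frozen CLIP ViT-L/14-336 vision tower.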
set -x  # echo each command before it runs
wandb login
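# One host-side CUDA work queue per GPU; some distributed training stacks
# assume this for predictable compute/communication overlap.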
export CUDA_DEVICE_MAX_CONNECTIONS=1
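# Topology: 4 nodes x 8 GPUs. Global batch size =
# NNODES * GPUS_PER_NODE * BATCH_SIZE * GRADIENT_ACCU_STEPS = 4*8*8*1 = 256.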
export GPUS_PER_NODE=8
export NNODES=4
export BATCH_SIZE=8
export GRADIENT_ACCU_STEPS=1
export MASTER_PORT=29502
export CPUS_PER_TASK=24
export QUOTA=reserved
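# LCS-558K pretraining data (BLIP captions over LAION/CC/SBU); the output
# directory name encodes the recipe (MLP-only tuning, lr 1e-3).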
export DATA_PATH=data/llava/llava_pretrain/blip_laion_cc_sbu_558k_add-path.json
export SAVE_PATH=llava-v1.6-7b_vicuna-7b_pretrain_lcs-558k_ft-mlp-lr-1e-3
export BASE_LR=1e-3
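# Extra srun flags may be supplied via SRUN_ARGS; defaults to empty.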
SRUN_ARGS=${SRUN_ARGS:-""}
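# One srun task per node; each task launches torchrun with GPUS_PER_NODE local
# processes. Rank 0's address is resolved from the first host in
# SLURM_NODELIST. Replace YOUR_PARTITION with your Slurm partition;
# --quotatype is a site-specific Slurm extension, drop it if your scheduler
# does not support it.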
PYTHONPATH="$(dirname "$0")/..":$PYTHONPATH \
srun -p YOUR_PARTITION \
--nodes=$NNODES \
--ntasks-per-node=1 \
--gres=gpu:$GPUS_PER_NODE \
--cpus-per-task=$CPUS_PER_TASK \
--kill-on-bad-exit=1 \
--quotatype=${QUOTA} \
${SRUN_ARGS} \
bash -c 'torchrun \
--nnodes $NNODES \
--nproc_per_node $GPUS_PER_NODE \
--node_rank $SLURM_NODEID \
--master_addr $(scontrol show hostname $SLURM_NODELIST | head -n1) \
--master_port ${MASTER_PORT} \
llava/train/train_mem.py \
--deepspeed ./scripts/zero2.json \
--model_name_or_path pretrained/vicuna/vicuna-7b-v1.5 \
--version plain \
--data_path ${DATA_PATH} \
--image_folder data \
--vision_tower openai/clip-vit-large-patch14-336 \
--mm_projector_type mlp2x_gelu \
--tune_mm_mlp_adapter True \
--unfreeze_mm_vision_tower False \
--image_aspect_ratio anyres \
--mm_vision_select_layer -2 \
--mm_vision_select_feature patch \
--mm_patch_merge_type spatial_unpad \
--mm_use_im_start_end False \
--mm_use_im_patch_token False \
--bf16 True \
--output_dir checkpoints/${SAVE_PATH} \
--num_train_epochs 1 \
--per_device_train_batch_size ${BATCH_SIZE} \
--per_device_eval_batch_size 4 \
--gradient_accumulation_steps ${GRADIENT_ACCU_STEPS} \
--evaluation_strategy "no" \
--save_strategy "steps" \
--save_steps 50000 \
--save_total_limit 1 \
--learning_rate ${BASE_LR} \
--weight_decay 0. \
--warmup_ratio 0.03 \
--lr_scheduler_type "cosine" \
--logging_steps 1 \
--tf32 True \
--model_max_length 4096 \
--gradient_checkpointing True \
--dataloader_num_workers 4 \
--lazy_preprocess True \
--report_to wandb \
--run_name ${SAVE_PATH}'
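
# Illustrative usage (the script path is an assumption; adjust for your repo):
#   SRUN_ARGS="--job-name=llava-pretrain" bash scripts/pretrain.sh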