forked from skypilot-org/skypilot
-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.yaml
100 lines (86 loc) · 3.04 KB
/
train.yaml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
envs:
  # Environment variables for this task. Entries left as null are TODOs:
  # fill them in here, or pass them at launch time with --env.

  # Your own Hugging Face token; the `setup` script uses it to log in to
  # the Hugging Face Hub.
  HF_TOKEN: null
  # Your unique bucket name; substituted into the /artifacts file mount.
  ARTIFACT_BUCKET_NAME: null
  # Your own WANDB_API_KEY; the `run` script checks whether it is empty.
  WANDB_API_KEY: null
  # Substituted into meta-llama/Llama-2-${MODEL_SIZE}b-hf and the output dir.
  MODEL_SIZE: 7
  # 1 installs xformers and trains with train_xformers.py; any other value
  # trains with train.py.
  USE_XFORMERS: 1
resources:
  # Accelerator request in NAME:COUNT form (8 x A100-80GB).
  accelerators: 'A100-80GB:8'
  # Disk size; presumably in GB -- confirm against the SkyPilot resources spec.
  disk_size: 1024
  # Request spot instances.
  use_spot: true

# Number of nodes in the cluster; the `run` script derives its own node
# count from SKYPILOT_NODE_IPS at runtime.
num_nodes: 1
file_mounts:
  # Storage named by $ARTIFACT_BUCKET_NAME, attached at /artifacts in MOUNT
  # mode. The training command writes its output_dir under /artifacts/chatbot/.
  /artifacts:
    name: $ARTIFACT_BUCKET_NAME
    mode: MOUNT

# Working directory for the task: the current directory. Both `setup` and
# `run` cd into a scripts directory, presumably relative to this one.
workdir: .
setup: |
  # Environment preparation: download the dataset, create/activate the
  # `chatbot` conda env, install FastChat (and optionally xformers), build the
  # merged training file, and log in to the Hugging Face Hub.

  # Download the ShareGPT dataset
  # Change to your OWN dataset if you want to train your own model
  mkdir -p "$HOME/data"
  wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json -O "$HOME/data/sharegpt.json"

  # Setup the environment: reuse the `chatbot` conda env when activation
  # succeeds, otherwise create it with Python 3.10 and activate it.
  conda activate chatbot
  if [ $? -ne 0 ]; then
    conda create -n chatbot python=3.10 -y
    conda activate chatbot
  fi

  cd ./scripts

  # Use an older version of fastchat to install transformers==4.28.1, as the transformers>=4.31
  # has issues with checkpoint saving -- saving additional large files in the checkpoint folder
  pip install git+https://github.com/lm-sys/FastChat.git@cfc73bf3e13c22ded81e89675e0d7b228cf4b342

  # Quoted with a default of 0 so an empty or unset USE_XFORMERS is a clean
  # "false" instead of a malformed test expression.
  if [ "${USE_XFORMERS:-0}" -eq 1 ]; then
    pip install -U xformers
  fi

  # Build the training data: hardcoded_questions.py presumably writes
  # hardcoded.json, which is merged with the ShareGPT file into mydata.json.
  python hardcoded_questions.py
  python -m fastchat.data.merge --in "$HOME/data/sharegpt.json" hardcoded.json --out "$HOME/data/mydata.json"

  # Pass the token to Python through the environment rather than splicing it
  # into the Python source, so a token containing a quote cannot break or
  # alter the command. The prefix assignment makes the variable visible to
  # this one process whether or not it is already exported.
  HF_TOKEN="${HF_TOKEN}" python -c "import os, huggingface_hub; huggingface_hub.login(os.environ['HF_TOKEN'])"
run: |
  # Launch FSDP fine-tuning of Llama-2 with torchrun on every node and exit
  # with torchrun's return code.
  cd scripts
  conda activate chatbot

  # Choose the training entry point. Quoted with a default of 0 so an empty
  # or unset USE_XFORMERS falls through to train.py without a test error.
  if [ "${USE_XFORMERS:-0}" -eq 1 ]; then
    TRAIN_SCRIPT=train_xformers.py
  else
    TRAIN_SCRIPT=train.py
  fi

  PER_DEVICE_BATCH_SIZE=4
  SEQ_LEN=2048

  # SKYPILOT_NODE_IPS is treated as one IP per line: the line count is the
  # node count and the first line is used as the torchrun master address.
  NUM_NODES=$(echo "$SKYPILOT_NODE_IPS" | wc -l)
  HOST_ADDR=$(echo "$SKYPILOT_NODE_IPS" | head -n1)

  # Turn off wandb if no api key is provided.
  # The variable must be quoted: unquoted, an empty key removes the operand
  # and the test itself fails, so this branch would never run. The mode must
  # be exported so that torchrun and its workers inherit it.
  if [ -z "${WANDB_API_KEY}" ]; then
    export WANDB_MODE="offline"
  fi

  # Accumulation steps are chosen so that
  #   SEQ_LEN * per-device batch * nodes * GPUs per node * steps = 128 * 512.
  # Integer division yields 0 once the divisors outgrow 128 * 512 (for
  # example 2 nodes x 8 GPUs with the values above), so clamp to at least 1.
  GRAD_ACCUM_STEPS=$((128 * 512 / $SEQ_LEN / $PER_DEVICE_BATCH_SIZE / $NUM_NODES / $SKYPILOT_NUM_GPUS_PER_NODE))
  if [ "$GRAD_ACCUM_STEPS" -lt 1 ]; then
    GRAD_ACCUM_STEPS=1
  fi

  torchrun \
    --nnodes=$NUM_NODES \
    --nproc_per_node=$SKYPILOT_NUM_GPUS_PER_NODE \
    --master_port=12375 \
    --master_addr=$HOST_ADDR \
    --node_rank=${SKYPILOT_NODE_RANK} \
    $TRAIN_SCRIPT \
    --model_name_or_path meta-llama/Llama-2-${MODEL_SIZE}b-hf \
    --data_path "$HOME/data/mydata.json" \
    --bf16 True \
    --output_dir /artifacts/chatbot/${MODEL_SIZE}b \
    --num_train_epochs 3 \
    --per_device_train_batch_size $PER_DEVICE_BATCH_SIZE \
    --per_device_eval_batch_size $PER_DEVICE_BATCH_SIZE \
    --gradient_accumulation_steps $GRAD_ACCUM_STEPS \
    --evaluation_strategy "no" \
    --save_strategy "steps" \
    --save_steps 600 \
    --save_total_limit 10 \
    --learning_rate 2e-5 \
    --weight_decay 0. \
    --warmup_ratio 0.03 \
    --lr_scheduler_type "cosine" \
    --logging_steps 1 \
    --fsdp "full_shard auto_wrap" \
    --fsdp_transformer_layer_cls_to_wrap 'LlamaDecoderLayer' \
    --tf32 True \
    --model_max_length ${SEQ_LEN} \
    --run_name "$SKYPILOT_TASK_ID" \
    --gradient_checkpointing True \
    --lazy_preprocess True

  # Propagate torchrun's exit status as the status of the whole run script.
  returncode=$?
  exit $returncode