[tool]GQA convert support (microsoft#454)
* [tools]GQA convert support

* fix readme
inkcherry authored Dec 18, 2024
1 parent 676a482 commit c3df187
Showing 5 changed files with 62 additions and 31 deletions.
4 changes: 2 additions & 2 deletions examples_deepspeed/finetune_hf_llama/README.md
@@ -10,9 +10,9 @@ The pre-trained weights can be found at [Hugging Face - LLAMA-7B](https://huggin

#### 1. Converting Hugging Face Model Weights to Megatron-Deepspeed Model
```bash
-bash examples_deepspeed/finetune_hf_llama/finetune_llama.sh convert
+bash examples_deepspeed/finetune_hf_llama/finetune_llama.sh convert_hf2mds
```
-This command writes the Hugging Face model weights into the Megatron-Deepspeed model and saves it. You can adjust the parallel configuration in the script.
+This command writes the Hugging Face model weights into the Megatron-Deepspeed model and saves it. You can adjust the parallel configuration in the script. ```convert_mds2hf``` can convert a Megatron-Deepspeed model back into the Hugging Face format.

#### 2. Fine-tuning Process
```bash
8 changes: 7 additions & 1 deletion examples_deepspeed/finetune_hf_llama/ds_config.json
@@ -1,5 +1,11 @@
{
  "train_batch_size" : 256,
  "train_micro_batch_size_per_gpu": 16,
-  "steps_per_print": 1
+  "steps_per_print": 100,
+  "zero_optimization": {
+    "stage": 0
+  },
+  "bf16": {
+    "enabled": true
+  }
}
5 changes: 5 additions & 0 deletions examples_deepspeed/finetune_hf_llama/ds_config_empty.json
@@ -0,0 +1,5 @@
+{
+  "train_batch_size" : 256,
+  "train_micro_batch_size_per_gpu": 16,
+  "steps_per_print": 100
+}
10 changes: 9 additions & 1 deletion examples_deepspeed/finetune_hf_llama/finetune_llama.sh
@@ -43,6 +43,13 @@ cat <<EOT > $DS_CONFIG
}
EOT

if [ "$1" = "convert_hf2mds" ]; then
DS_CONFIG_PATH="./examples_deepspeed/finetune_hf_llama/ds_config_empty.json"
elif [ "$1" = "convert_mds2hf" ]; then
DS_CONFIG_PATH="./examples_deepspeed/finetune_hf_llama/ds_config_empty.json"
else
DS_CONFIG_PATH="./examples_deepspeed/finetune_hf_llama/ds_config.json"
fi

covert_hf2mds_args="deepspeed tools/hf2megads_weight_converter.py \
--hf-ckpt-num-shards 2 \
@@ -69,6 +76,7 @@ comm_args="--tensor-model-parallel-size $TP \
--num-layers $NUM_LAYERS \
--hidden-size $HIDDEN_SIZE \
--num-attention-heads $NUM_HEADS \
+    --finetune \
--ffn-hidden-size $FFN_HIDDEN_SIZE \
--attention-dropout 0 \
--hidden-dropout 0 \
@@ -97,7 +105,7 @@ comm_args="--tensor-model-parallel-size $TP \
--zero-stage 0 \
--tokenizer-type HFTokenizer \
--tokenizer-model $HF_LLAMA_PATH \
-    --deepspeed_config ./examples_deepspeed/finetune_hf_llama/ds_config.json \
+    --deepspeed_config $DS_CONFIG_PATH \
--deepspeed \
--distributed-backend nccl \
--num-workers 0 \
66 changes: 39 additions & 27 deletions tools/hf2megads_weight_converter.py
@@ -193,28 +193,43 @@ def _qkv_refactor(self, pname, p, hf_layer):
        wk = self.hf_model[hf_wk_name]
        wv = self.hf_model[hf_wv_name]

-        hidden_size = wq.shape[0]
-        per_partition_size, start_index, end_index = compute_partition_range(
-            hidden_size, self.tp_rank, self.tp_size)
-        hidden_size_per_attention_head = divide(hidden_size,
+        query_hidden_size = wq.shape[0]
+        kv_hidden_size = wk.shape[0]
+
+        per_partition_size, start_qindex, end_index = compute_partition_range(
+            query_hidden_size, self.tp_rank, self.tp_size)
+        _,start_kvindex, _= compute_partition_range(
+            kv_hidden_size, self.tp_rank, self.tp_size)
+
+        hidden_size_per_attention_head = divide(query_hidden_size,
                                                 self.config.num_attention_heads)
        num_attention_heads_per_partition = divide(self.config.num_attention_heads,
                                                    self.tp_size)

-        new_w = torch.zeros((per_partition_size * 3, wq.shape[1]), dtype=wq.dtype)
+        num_kv_heads_per_partition= divide(self.config.num_key_value_heads,
+                                           self.tp_size)
+        qkv_size=(num_attention_heads_per_partition+2*num_kv_heads_per_partition)*hidden_size_per_attention_head
+        num_qheads_per_group=divide(self.config.num_attention_heads,self.config.num_key_value_heads)
+        num_groups =divide(num_attention_heads_per_partition,num_qheads_per_group)
+        new_w = torch.zeros((qkv_size, wq.shape[1]), dtype=wq.dtype)

+        for i in range(num_groups):
+            query_current_index=start_qindex+i*num_qheads_per_group*hidden_size_per_attention_head
+            query_next_index=query_current_index+num_qheads_per_group*hidden_size_per_attention_head
+            kv_current_index=start_kvindex+i*hidden_size_per_attention_head
+            kv_next_kvindex=kv_current_index+hidden_size_per_attention_head
+
+            new_w_index=i* (num_qheads_per_group+2)*hidden_size_per_attention_head
+
-        for i in range(num_attention_heads_per_partition):
-            current_index = start_index + i * hidden_size_per_attention_head
-            next_index = current_index + hidden_size_per_attention_head
-            new_w_index = i * (3 * hidden_size_per_attention_head)
-            new_w[new_w_index: new_w_index + (3 * hidden_size_per_attention_head), :] = \
+            new_w[new_w_index:new_w_index+(num_qheads_per_group+2)*hidden_size_per_attention_head,:]=\
                torch.cat([
-                    wq[current_index: next_index, :],
-                    wk[current_index: next_index, :],
-                    wv[current_index: next_index, :]
-                ], dim=0)
+                    wq[query_current_index:query_next_index,:],
+                    wk[kv_current_index:kv_next_kvindex,:],
+                    wv[kv_current_index:kv_next_kvindex,:]
+                ],dim=0)

        self.record_mapping_info(
-            f"mega-ds:{pname,p.data.shape}<--hf{hf_wq_name,hf_wk_name,hf_wv_name,} cat q,k,v [{current_index}:{next_index},:] of q,k,v{wq.shape}"
+            f"mega-ds:{pname,p.data.shape}<--hf{hf_wq_name,hf_wk_name,hf_wv_name,} cat q,k,v [{query_current_index}:{query_next_index},:] of q,k,v{wq.shape}"
        )
        return new_w
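The loop above packs the HF `q_proj`/`k_proj`/`v_proj` weights into one fused matrix in which each KV head contributes a contiguous block of `num_qheads_per_group` query heads followed by its key head and its value head. A minimal standalone sketch of that packing, assuming toy GQA dimensions, a single tensor-parallel rank, and hypothetical variable names (not the repository code):

```python
# Illustrative sketch only -- toy sizes, tp_size == 1; not part of the commit.
import torch

num_heads, num_kv_heads, head_dim, hidden = 8, 2, 4, 32   # hypothetical GQA config
wq = torch.randn(num_heads * head_dim, hidden)             # HF q_proj.weight
wk = torch.randn(num_kv_heads * head_dim, hidden)          # HF k_proj.weight
wv = torch.randn(num_kv_heads * head_dim, hidden)          # HF v_proj.weight

group = num_heads // num_kv_heads                           # query heads per KV head
new_w = torch.zeros((num_heads + 2 * num_kv_heads) * head_dim, hidden, dtype=wq.dtype)

for i in range(num_kv_heads):
    # Rows for KV group i: its `group` query heads, then its key head, then its value head.
    q_block = wq[i * group * head_dim:(i + 1) * group * head_dim, :]
    k_block = wk[i * head_dim:(i + 1) * head_dim, :]
    v_block = wv[i * head_dim:(i + 1) * head_dim, :]
    start = i * (group + 2) * head_dim
    new_w[start:start + (group + 2) * head_dim, :] = torch.cat([q_block, k_block, v_block], dim=0)

print(new_w.shape)  # torch.Size([48, 32]) == ((8 + 2 * 2) * 4, hidden)
```

With `num_heads == num_kv_heads` (plain MHA) the grouping degenerates to the previous per-head `[q | k | v]` interleaving, which is why a single code path can serve both cases.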

@@ -383,17 +398,18 @@ def _qkv_refactor_to_hf(self, pname, ds_w, hf_layer):
        hidden_size = oldshape[-1]
        hidden_size_per_attention_head = divide(hidden_size,
                                                self.config.num_attention_heads)
-        num_attention_heads_per_partition = divide(self.config.num_attention_heads,
-                                                   self.tp_size)
-        newshape = (self.tp_size, num_attention_heads_per_partition, 3, hidden_size_per_attention_head, hidden_size)
+        # MHA & GQA
+        group = divide(self.config.num_attention_heads, self.config.num_key_value_heads)
+        newshape = (self.config.num_key_value_heads, group + 2, hidden_size_per_attention_head, hidden_size)
        ds_w_out = ds_w_all_rank.reshape(*newshape)
-        self.hf_dict[hf_q_name] = copy.deepcopy(ds_w_out[:, :, 0, :, :].reshape(-1, oldshape[-1]))
-        self.hf_dict[hf_k_name] = copy.deepcopy(ds_w_out[:, :, 1, :, :].reshape(-1, oldshape[-1]))
-        self.hf_dict[hf_v_name] = copy.deepcopy(ds_w_out[:, :, 2, :, :].reshape(-1, oldshape[-1]))
+        query_weight, key_weight, value_weight = torch.split(ds_w_out, [group, 1, 1], dim=1)
+        self.hf_dict[hf_q_name] = copy.deepcopy(query_weight.reshape(-1, hidden_size))
+        self.hf_dict[hf_k_name] = copy.deepcopy(key_weight.reshape(-1, hidden_size))
+        self.hf_dict[hf_v_name] = copy.deepcopy(value_weight.reshape(-1, hidden_size))
+        del query_weight, key_weight, value_weight


    def transform_from_megads_to_hf(self):
-        use_gqa = True if self.num_attention_heads != self.num_key_value_heads else False

        for pname, p in self.ds_model.named_parameters():
            if pname in [
@@ -411,11 +427,7 @@ def transform_from_megads_to_hf(self):
                subname = mobj.group(2)
                hf_layer = layer_num - self.offset_num
                if subname in ["self_attention.query_key_value.weight"]:
-                    if not use_gqa:
-                        self._qkv_refactor_to_hf(pname, p, hf_layer)
-                    else:
-                        #TODO(billishyahao): Not impl yet ...
-                        assert False
+                    self._qkv_refactor_to_hf(pname, p, hf_layer)
                elif subname in ["mlp.dense_h_to_4h.weight"]:
                    self._mlphto4h_dense_refactor_to_hf(pname, p, hf_layer)
                elif subname in [
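The reverse mapping in `_qkv_refactor_to_hf` views the fused weight as `(num_key_value_heads, group + 2, head_dim, hidden)` and splits `group` query heads plus one key and one value head out of each KV group. A matching round-trip sketch, again with toy sizes and hypothetical names rather than the repository code:

```python
# Illustrative sketch only -- toy sizes, tp_size == 1; not part of the commit.
import torch

num_heads, num_kv_heads, head_dim, hidden = 8, 2, 4, 32   # hypothetical GQA config
group = num_heads // num_kv_heads

wq = torch.randn(num_heads * head_dim, hidden)
wk = torch.randn(num_kv_heads * head_dim, hidden)
wv = torch.randn(num_kv_heads * head_dim, hidden)

# Pack per KV group: [group query heads | key head | value head], as in _qkv_refactor.
blocks = []
for i in range(num_kv_heads):
    blocks += [wq[i * group * head_dim:(i + 1) * group * head_dim, :],
               wk[i * head_dim:(i + 1) * head_dim, :],
               wv[i * head_dim:(i + 1) * head_dim, :]]
fused = torch.cat(blocks, dim=0)

# Unpack: reshape to (kv_heads, group + 2, head_dim, hidden), then split along dim=1.
q, k, v = torch.split(fused.reshape(num_kv_heads, group + 2, head_dim, hidden),
                      [group, 1, 1], dim=1)
assert torch.equal(q.reshape(-1, hidden), wq)
assert torch.equal(k.reshape(-1, hidden), wk)
assert torch.equal(v.reshape(-1, hidden), wv)
```

Because this reshape-and-split handles both MHA and GQA, the earlier `use_gqa` guard and its `assert False` fallback in `transform_from_megads_to_hf` are no longer needed.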
