[Bugs] Fix dispatch attn bug (#829)
* fix collate bug

* fix dispatch attn bugs

* fix rotary_seq_len bug
HIT-cwh authored Jul 19, 2024
1 parent d58c1dd commit 16e2f8f
Showing 7 changed files with 39 additions and 6 deletions.
4 changes: 2 additions & 2 deletions xtuner/dataset/collate_fns/default_collate_fn.py
@@ -56,8 +56,8 @@ def default_collate_fn(instances: Sequence[Dict],
         # Some tokenizers have the same eos token and pad token, so input_ids
         # cannot be masked directly based on the pad token id.
         attention_mask = torch.zeros_like(input_ids).bool()
-        for i in ori_length:
-            attention_mask[:i] = True
+        for i, length in enumerate(ori_length):
+            attention_mask[i, :length] = True

         bs, seq_len = input_ids.shape
         position_ids = torch.arange(seq_len).unsqueeze(0).long().repeat(bs, 1)
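
The old loop iterated over the lengths themselves and sliced the batch dimension, so attention_mask[:i] = True unmasked whole rows, padding included. A minimal standalone sketch of the difference, using a toy batch and illustrative names (buggy_mask, fixed_mask) that are not part of the repo:

import torch

# Toy right-padded batch: 2 samples padded to length 5, original lengths 3 and 5.
input_ids = torch.zeros(2, 5, dtype=torch.long)
ori_length = [3, 5]

# Old behaviour: `i` is a length but is used as a row slice, so every row of
# the batch ends up fully unmasked, padding included.
buggy_mask = torch.zeros_like(input_ids).bool()
for i in ori_length:
    buggy_mask[:i] = True

# Fixed behaviour: unmask only the first `length` tokens of sample `i`.
fixed_mask = torch.zeros_like(input_ids).bool()
for i, length in enumerate(ori_length):
    fixed_mask[i, :length] = True

print(buggy_mask)  # all True, so padded positions of sample 0 are attended
print(fixed_mask)  # row 0 -> [True, True, True, False, False]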
7 changes: 6 additions & 1 deletion xtuner/model/modules/dispatch/cohere.py
@@ -105,17 +105,22 @@ def cohere_attn_forward(
         query_states, key_states, value_states = \
             pre_process_for_sequence_parallel_attn(
                 query_states, key_states, value_states)
+        # self.num_heads is used in self._upad_input method
+        # num_heads has been changed because of sequence parallel
+        ori_num_head = self.num_heads
+        self.num_heads = query_states.shape[-2]

     attn_output = self._flash_attention_forward(
         query_states,
         key_states,
         value_states,
         attention_mask,
-        q_len,
+        query_states.shape[1],
         dropout=dropout_rate)

     if enable_sequence_parallel:
         attn_output = post_process_for_sequence_parallel_attn(attn_output)
+        self.num_heads = ori_num_head

     attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
     attn_output = self.o_proj(attn_output)
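
The same pattern is applied to every dispatched attention below: pre_process_for_sequence_parallel_attn all-to-alls the (b, s/sp, h, d) shards into (b, s, h/sp, d) (the inverse of the shape comment shown in the phi3 diff), so both the head count cached on the module and the sequence length passed to flash attention must reflect the post-shuffle tensor, and the original num_heads is restored afterwards. A standalone sketch of that bookkeeping with dummy shapes; DummyAttn, the sizes, and the simulated tensors are illustrative only, not xtuner code:

import torch

bsz, seq_len, num_heads, head_dim, sp_world_size = 2, 8, 4, 16, 2

# Before pre-processing each rank holds its sequence shard with every head.
q_local = torch.randn(bsz, seq_len // sp_world_size, num_heads, head_dim)
# After pre-processing (simulated here) it holds the full sequence but only
# num_heads // sp_world_size heads.
q_sp = torch.randn(bsz, seq_len, num_heads // sp_world_size, head_dim)

class DummyAttn:
    """Stand-in for an HF attention module that caches num_heads."""
    def __init__(self, num_heads):
        self.num_heads = num_heads

attn = DummyAttn(num_heads)

# The pattern added by the commit: temporarily expose the post-shuffle values,
# run flash attention with them, then restore the original attribute.
ori_num_head = attn.num_heads
attn.num_heads = q_sp.shape[-2]   # heads held locally, used by _upad_input
flash_q_len = q_sp.shape[1]       # gathered sequence length, not the local q_len
# ... flash attention would run here on q_sp with flash_q_len ...
attn.num_heads = ori_num_head     # restore so the next forward starts clean
print(q_local.shape, q_sp.shape, flash_q_len, attn.num_heads)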
5 changes: 5 additions & 0 deletions xtuner/model/modules/dispatch/deepseek_v2.py
@@ -128,6 +128,10 @@ def deepseek_attn_forward(
         query_states, key_states, value_states = \
             pre_process_for_sequence_parallel_attn(
                 query_states, key_states, value_states)
+        # self.num_heads is used in self._upad_input method
+        # num_heads has been changed because of sequence parallel
+        ori_num_head = self.num_heads
+        self.num_heads = query_states.shape[-2]

     attn_output = self._flash_attention_forward(
         query_states,
@@ -141,6 +145,7 @@

     if enable_sequence_parallel:
         attn_output = post_process_for_sequence_parallel_attn(attn_output)
+        self.num_heads = ori_num_head

     if self.q_head_dim != self.v_head_dim:
         attn_output = attn_output[:, :, :, :self.v_head_dim]
5 changes: 5 additions & 0 deletions xtuner/model/modules/dispatch/internlm2.py
@@ -149,6 +149,10 @@ def internlm2_attn_forward(
         query_states, key_states, value_states = \
             pre_process_for_sequence_parallel_attn(
                 query_states, key_states, value_states)
+        # self.num_heads is used in self._upad_input method
+        # num_heads has been changed because of sequence parallel
+        ori_num_head = self.num_heads
+        self.num_heads = query_states.shape[-2]

     dropout_rate = 0.0
     attn_output = self._flash_attention_forward(
@@ -161,6 +165,7 @@

     if enable_sequence_parallel:
         attn_output = post_process_for_sequence_parallel_attn(attn_output)
+        self.num_heads = ori_num_head

     attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
     attn_output = self.wo(attn_output)
8 changes: 7 additions & 1 deletion xtuner/model/modules/dispatch/mistral.py
@@ -214,6 +214,11 @@ def mistral_attn_forward(
         query_states, key_states, value_states = \
             pre_process_for_sequence_parallel_attn(
                 query_states, key_states, value_states)
+        # num_heads has been changed because of sequence parallel
+        # `self.num_heads` is not used in self._flash_attention_forward
+        # in mistral/mixtral; we do this anyway to avoid unnecessary risk
+        ori_num_head = self.num_heads
+        self.num_heads = query_states.shape[-2]

     attn_output = self._flash_attention_forward(
         query_states,
@@ -227,6 +232,7 @@

     if enable_sequence_parallel:
         attn_output = post_process_for_sequence_parallel_attn(attn_output)
+        self.num_heads = ori_num_head

     attn_output = attn_output.reshape(bsz, q_len,
                                       self.hidden_size).contiguous()
@@ -311,7 +317,7 @@ def mistral_varlen_attn_forward(
     value_states = value_states.transpose(1, 2)
     # Because the input can be padded, the absolute sequence length
     # depends on the max position id.
-    rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item() + 1)
+    rotary_seq_len = max(kv_seq_len, position_ids.max().item() + 1)
     cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len)
     query_states, key_states = apply_rotary_pos_emb(
         query_states, key_states, cos, sin, position_ids)
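
position_ids[:, -1] is only the position of the last token in each row, which is not necessarily the largest position id; with varlen (packed) batches the position ids restart at every packed-sample boundary, so the old expression could size the rotary table too small. The phi3 hunk below additionally moves the + 1 inside the max, so kv_seq_len itself is no longer incremented. A toy illustration with made-up numbers:

import torch

# One packed row: a 5-token sample followed by a 3-token sample, so position
# ids restart at the boundary and the last element is not the maximum.
position_ids = torch.tensor([[0, 1, 2, 3, 4, 0, 1, 2]])
kv_seq_len = 0  # e.g. no KV cache yet

old_len = max(kv_seq_len, position_ids[:, -1].max().item() + 1)  # 3, too short
new_len = max(kv_seq_len, position_ids.max().item() + 1)         # 5, covers position 4
print(old_len, new_len)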
8 changes: 7 additions & 1 deletion xtuner/model/modules/dispatch/phi3.py
@@ -233,6 +233,11 @@ def phi3_attn_forward(
             pre_process_for_sequence_parallel_attn(
                 query_states, key_states, value_states,
                 scatter_dim=2, gather_dim=1)
+        # num_heads has been changed because of sequence parallel
+        # `self.num_heads` is not used in self._flash_attention_forward
+        # in mistral/mixtral; we do this anyway to avoid unnecessary risk
+        ori_num_head = self.num_heads
+        self.num_heads = query_states.shape[-2]

     attn_output = self._flash_attention_forward(
         query_states,
@@ -248,6 +253,7 @@
         # (b, s, nd // sp_world_size, dim) -> (b, s // sp_world_size, nd, dim)
         attn_output = post_process_for_sequence_parallel_attn(
             attn_output, scatter_dim=1, gather_dim=2)
+        self.num_heads = ori_num_head

     attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
     attn_output = self.o_proj(attn_output)
@@ -333,7 +339,7 @@ def phi3_varlen_attn_forward(
         self.layer_idx)

     assert position_ids is not None
-    rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1
+    rotary_seq_len = max(kv_seq_len, position_ids.max().item() + 1)
     cos, sin = self.rotary_emb(
         value_states, position_ids, seq_len=rotary_seq_len)

8 changes: 7 additions & 1 deletion xtuner/model/modules/dispatch/qwen2.py
@@ -151,6 +151,11 @@ def qwen2_attn_forward(
         query_states, key_states, value_states = \
             pre_process_for_sequence_parallel_attn(
                 query_states, key_states, value_states)
+        # num_heads has been changed because of sequence parallel
+        # `self.num_heads` is not used in self._flash_attention_forward
+        # in mistral/mixtral; we do this anyway to avoid unnecessary risk
+        ori_num_head = self.num_heads
+        self.num_heads = query_states.shape[-2]

     attn_output = self._flash_attention_forward(
         query_states,
@@ -164,6 +169,7 @@

     if enable_sequence_parallel:
         attn_output = post_process_for_sequence_parallel_attn(attn_output)
+        self.num_heads = ori_num_head

     attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
     attn_output = self.o_proj(attn_output)
@@ -227,7 +233,7 @@ def qwen2_varlen_attn_forward(
         self.layer_idx)

     assert position_ids is not None
-    rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item() + 1)
+    rotary_seq_len = max(kv_seq_len, position_ids.max().item() + 1)
     cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len)

     query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
