From 16e2f8f1abc272f2b31ec900f6c94731e20a49d6 Mon Sep 17 00:00:00 2001
From: whcao <41630003+HIT-cwh@users.noreply.github.com>
Date: Fri, 19 Jul 2024 17:10:37 +0800
Subject: [PATCH] [Bugs] Fix dispatch attn bug (#829)

* fix collate bug

* fix dispatch attn bugs

* fix rotary_seq_len bug
---
 xtuner/dataset/collate_fns/default_collate_fn.py | 4 ++--
 xtuner/model/modules/dispatch/cohere.py          | 7 ++++++-
 xtuner/model/modules/dispatch/deepseek_v2.py     | 5 +++++
 xtuner/model/modules/dispatch/internlm2.py       | 5 +++++
 xtuner/model/modules/dispatch/mistral.py         | 8 +++++++-
 xtuner/model/modules/dispatch/phi3.py            | 8 +++++++-
 xtuner/model/modules/dispatch/qwen2.py           | 8 +++++++-
 7 files changed, 39 insertions(+), 6 deletions(-)

diff --git a/xtuner/dataset/collate_fns/default_collate_fn.py b/xtuner/dataset/collate_fns/default_collate_fn.py
index f4d5f9197..0ca9264f0 100644
--- a/xtuner/dataset/collate_fns/default_collate_fn.py
+++ b/xtuner/dataset/collate_fns/default_collate_fn.py
@@ -56,8 +56,8 @@ def default_collate_fn(instances: Sequence[Dict],
     # Some tokenizers have the same eos token and pad token, so input_ids
     # cannot be masked directly based on the pad token id.
     attention_mask = torch.zeros_like(input_ids).bool()
-    for i in ori_length:
-        attention_mask[:i] = True
+    for i, length in enumerate(ori_length):
+        attention_mask[i, :length] = True

     bs, seq_len = input_ids.shape
     position_ids = torch.arange(seq_len).unsqueeze(0).long().repeat(bs, 1)
diff --git a/xtuner/model/modules/dispatch/cohere.py b/xtuner/model/modules/dispatch/cohere.py
index edeb771e3..d3529f570 100644
--- a/xtuner/model/modules/dispatch/cohere.py
+++ b/xtuner/model/modules/dispatch/cohere.py
@@ -105,17 +105,22 @@ def cohere_attn_forward(
         query_states, key_states, value_states = \
             pre_process_for_sequence_parallel_attn(
                 query_states, key_states, value_states)
+        # self.num_heads is used in self._upad_input method
+        # num_heads has been changed because of sequence parallel
+        ori_num_head = self.num_heads
+        self.num_heads = query_states.shape[-2]

     attn_output = self._flash_attention_forward(
         query_states,
         key_states,
         value_states,
         attention_mask,
-        q_len,
+        query_states.shape[1],
         dropout=dropout_rate)

     if enable_sequence_parallel:
         attn_output = post_process_for_sequence_parallel_attn(attn_output)
+        self.num_heads = ori_num_head

     attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
     attn_output = self.o_proj(attn_output)
diff --git a/xtuner/model/modules/dispatch/deepseek_v2.py b/xtuner/model/modules/dispatch/deepseek_v2.py
index dcdb677a3..667d2227c 100644
--- a/xtuner/model/modules/dispatch/deepseek_v2.py
+++ b/xtuner/model/modules/dispatch/deepseek_v2.py
@@ -128,6 +128,10 @@ def deepseek_attn_forward(
         query_states, key_states, value_states = \
             pre_process_for_sequence_parallel_attn(
                 query_states, key_states, value_states)
+        # self.num_heads is used in self._upad_input method
+        # num_heads has been changed because of sequence parallel
+        ori_num_head = self.num_heads
+        self.num_heads = query_states.shape[-2]

     attn_output = self._flash_attention_forward(
         query_states,
@@ -141,6 +145,7 @@

     if enable_sequence_parallel:
         attn_output = post_process_for_sequence_parallel_attn(attn_output)
+        self.num_heads = ori_num_head

     if self.q_head_dim != self.v_head_dim:
         attn_output = attn_output[:, :, :, :self.v_head_dim]
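For context on the collate fix above: the old loop iterated over the raw lengths and wrote `attention_mask[:i] = True`, which unmasks whole rows up to index `i` instead of the first `i` tokens of sample `i`, so the final mask depended only on the last lengths iterated rather than on each sample's own length. A minimal, self-contained sketch of the corrected per-sample masking (the batch and lengths below are made up for illustration, not taken from xtuner):

```python
import torch

# Hypothetical right-padded batch: three sequences of lengths 3, 5 and 2,
# padded to the max length 5.
ori_length = [3, 5, 2]
input_ids = torch.zeros(len(ori_length), max(ori_length), dtype=torch.long)

# Buggy version: `attention_mask[:i] = True` fills whole *rows* 0..i-1.
# Fixed version: unmask the first `length` positions of row `i`.
attention_mask = torch.zeros_like(input_ids).bool()
for i, length in enumerate(ori_length):
    attention_mask[i, :length] = True

print(attention_mask)
# tensor([[ True,  True,  True, False, False],
#         [ True,  True,  True,  True,  True],
#         [ True,  True, False, False, False]])
```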
diff --git a/xtuner/model/modules/dispatch/internlm2.py b/xtuner/model/modules/dispatch/internlm2.py
index 5b855d4ab..7c601f0dc 100644
--- a/xtuner/model/modules/dispatch/internlm2.py
+++ b/xtuner/model/modules/dispatch/internlm2.py
@@ -149,6 +149,10 @@ def internlm2_attn_forward(
         query_states, key_states, value_states = \
             pre_process_for_sequence_parallel_attn(
                 query_states, key_states, value_states)
+        # self.num_heads is used in self._upad_input method
+        # num_heads has been changed because of sequence parallel
+        ori_num_head = self.num_heads
+        self.num_heads = query_states.shape[-2]

     dropout_rate = 0.0
     attn_output = self._flash_attention_forward(
@@ -161,6 +165,7 @@

     if enable_sequence_parallel:
         attn_output = post_process_for_sequence_parallel_attn(attn_output)
+        self.num_heads = ori_num_head

     attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
     attn_output = self.wo(attn_output)
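The dispatch changes above all follow one pattern: with sequence parallelism enabled, `pre_process_for_sequence_parallel_attn` performs an all-to-all that leaves each rank with the full sequence but only `num_heads // sp_world_size` heads, while Transformers' flash-attention helpers (notably `_upad_input`) still consult `self.num_heads`. The patch therefore points `self.num_heads` at the post-scatter head count for the duration of the call and restores it afterwards. A toy, runnable sketch of that save/restore idea (the `DummyAttention` class, its simplified `_flash_attention_forward` signature, and the shapes below are illustrative stand-ins, not xtuner's real module):

```python
import torch


class DummyAttention:
    """Toy stand-in for an HF attention module whose flash-attention
    helpers (e.g. ``_upad_input``) read ``self.num_heads``."""

    def __init__(self, num_heads):
        self.num_heads = num_heads

    def _flash_attention_forward(self, q, k, v):
        # Like HF's `_upad_input`, trust `self.num_heads` to describe q/k/v.
        assert q.shape[-2] == self.num_heads, 'head count mismatch'
        return q  # placeholder for the real flash-attention call


attn = DummyAttention(num_heads=32)
sp_world_size = 4  # hypothetical sequence-parallel world size

# After the sequence-parallel all-to-all, each rank holds the full sequence
# but only num_heads // sp_world_size heads: (bsz, seq, nh // sp, head_dim).
q = k = v = torch.randn(1, 4096, 32 // sp_world_size, 128)

ori_num_head = attn.num_heads   # save, as the patch does
attn.num_heads = q.shape[-2]    # match the sharded tensors
out = attn._flash_attention_forward(q, k, v)
attn.num_heads = ori_num_head   # restore once post-processing is done
```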
diff --git a/xtuner/model/modules/dispatch/mistral.py b/xtuner/model/modules/dispatch/mistral.py
index 49dfdc108..d08b0f00e 100644
--- a/xtuner/model/modules/dispatch/mistral.py
+++ b/xtuner/model/modules/dispatch/mistral.py
@@ -214,6 +214,11 @@ def mistral_attn_forward(
         query_states, key_states, value_states = \
             pre_process_for_sequence_parallel_attn(
                 query_states, key_states, value_states)
+        # num_heads has been changed because of sequence parallel
+        # `self.num_heads`` is not used in self._flash_attention_forward
+        # in mistral/mixtral, we are doing this to avoid some unnecessary risk
+        ori_num_head = self.num_heads
+        self.num_heads = query_states.shape[-2]

     attn_output = self._flash_attention_forward(
         query_states,
@@ -227,6 +232,7 @@

     if enable_sequence_parallel:
         attn_output = post_process_for_sequence_parallel_attn(attn_output)
+        self.num_heads = ori_num_head

     attn_output = attn_output.reshape(bsz, q_len,
                                       self.hidden_size).contiguous()
@@ -311,7 +317,7 @@ def mistral_varlen_attn_forward(
     value_states = value_states.transpose(1, 2)
     # Because the input can be padded, the absolute sequence length
     # depends on the max position id.
-    rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item() + 1)
+    rotary_seq_len = max(kv_seq_len, position_ids.max().item() + 1)
     cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len)
     query_states, key_states = apply_rotary_pos_emb(
         query_states, key_states, cos, sin, position_ids)
diff --git a/xtuner/model/modules/dispatch/phi3.py b/xtuner/model/modules/dispatch/phi3.py
index 4003c9d62..97ebc8d33 100644
--- a/xtuner/model/modules/dispatch/phi3.py
+++ b/xtuner/model/modules/dispatch/phi3.py
@@ -233,6 +233,11 @@ def phi3_attn_forward(
             pre_process_for_sequence_parallel_attn(
                 query_states, key_states, value_states,
                 scatter_dim=2, gather_dim=1)
+        # num_heads has been changed because of sequence parallel
+        # `self.num_heads`` is not used in self._flash_attention_forward
+        # in mistral/mixtral, we are doing this to avoid some unnecessary risk
+        ori_num_head = self.num_heads
+        self.num_heads = query_states.shape[-2]

     attn_output = self._flash_attention_forward(
         query_states,
@@ -248,6 +253,7 @@
         # (b, s, nd // sp_world_size, dim) -> (b, s // sp_world_size, nd, dim)
         attn_output = post_process_for_sequence_parallel_attn(
             attn_output, scatter_dim=1, gather_dim=2)
+        self.num_heads = ori_num_head

     attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
     attn_output = self.o_proj(attn_output)
@@ -333,7 +339,7 @@ def phi3_varlen_attn_forward(
                                                        self.layer_idx)

     assert position_ids is not None
-    rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1
+    rotary_seq_len = max(kv_seq_len, position_ids.max().item() + 1)
     cos, sin = self.rotary_emb(
         value_states, position_ids, seq_len=rotary_seq_len)

diff --git a/xtuner/model/modules/dispatch/qwen2.py b/xtuner/model/modules/dispatch/qwen2.py
index d89dbf947..1c8c5a8d0 100644
--- a/xtuner/model/modules/dispatch/qwen2.py
+++ b/xtuner/model/modules/dispatch/qwen2.py
@@ -151,6 +151,11 @@ def qwen2_attn_forward(
         query_states, key_states, value_states = \
             pre_process_for_sequence_parallel_attn(
                 query_states, key_states, value_states)
+        # num_heads has been changed because of sequence parallel
+        # `self.num_heads`` is not used in self._flash_attention_forward
+        # in mistral/mixtral, we are doing this to avoid some unnecessary risk
+        ori_num_head = self.num_heads
+        self.num_heads = query_states.shape[-2]

     attn_output = self._flash_attention_forward(
         query_states,
@@ -164,6 +169,7 @@

     if enable_sequence_parallel:
         attn_output = post_process_for_sequence_parallel_attn(attn_output)
+        self.num_heads = ori_num_head

     attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
     attn_output = self.o_proj(attn_output)
@@ -227,7 +233,7 @@ def qwen2_varlen_attn_forward(
                                                        self.layer_idx)

     assert position_ids is not None
-    rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item() + 1)
+    rotary_seq_len = max(kv_seq_len, position_ids.max().item() + 1)
     cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len)

     query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
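The `rotary_seq_len` change in the varlen forwards matters for packed batches: position ids restart at every sample boundary, so the last column `position_ids[:, -1]` is not necessarily the largest position and can under-size the rotary embedding cache. Taking the global maximum fixes that (and, in the phi3 hunk, the `+ 1` also moves inside `max` so `kv_seq_len` is compared against a true length). A small sketch with hypothetical packed position ids:

```python
import torch

# Hypothetical packed row: two samples of lengths 5 and 3 share one sequence,
# so the position ids restart at the second sample's boundary.
position_ids = torch.tensor([[0, 1, 2, 3, 4, 0, 1, 2]])
kv_seq_len = 1  # placeholder; the real value comes from the kv cache

# Old: looks only at the last column -> 2 + 1 = 3, too short for the first
# sample, whose rotary embeddings need positions up to 4.
old_rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item() + 1)

# New: covers every position id in the packed row -> 4 + 1 = 5.
new_rotary_seq_len = max(kv_seq_len, position_ids.max().item() + 1)

print(old_rotary_seq_len, new_rotary_seq_len)  # 3 5
```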