Skip to content

Commit

Permalink
fix rotary_seq_len bug
Browse files Browse the repository at this point in the history
  • Loading branch information
HIT-cwh committed Jul 12, 2024
1 parent 6d83c33 commit a104a42
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 3 deletions.
2 changes: 1 addition & 1 deletion xtuner/model/modules/dispatch/mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,7 @@ def mistral_varlen_attn_forward(
value_states = value_states.transpose(1, 2)
# Because the input can be padded, the absolute sequence length
# depends on the max position id.
- rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item() + 1)
+ rotary_seq_len = max(kv_seq_len, position_ids.max().item() + 1)
cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len)
query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin, position_ids)
Expand Down
2 changes: 1 addition & 1 deletion xtuner/model/modules/dispatch/phi3.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,7 +339,7 @@ def phi3_varlen_attn_forward(
self.layer_idx)

assert position_ids is not None
- rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1
+ rotary_seq_len = max(kv_seq_len, position_ids.max().item() + 1)
cos, sin = self.rotary_emb(
value_states, position_ids, seq_len=rotary_seq_len)

Expand Down
2 changes: 1 addition & 1 deletion xtuner/model/modules/dispatch/qwen2.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ def qwen2_varlen_attn_forward(
self.layer_idx)

assert position_ids is not None
- rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item() + 1)
+ rotary_seq_len = max(kv_seq_len, position_ids.max().item() + 1)
cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len)

query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
Expand Down

0 comments on commit a104a42

Please sign in to comment.