fix flash attention for mistral. (#758)
## fix flash attention for mistral. 

This pull request fixes the flash attention forward method for Mistral so that it calls the shared `_flash_attention_forward` utility from `transformers.modeling_flash_attention_utils` instead of relying on the removed `_flash_supports_window_size` helper and the per-model sliding-window handling.
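
For context, a minimal usage sketch of the code path this commit touches (a sketch only, assuming the standard `adapters` and `transformers` APIs; the checkpoint name and adapter config are illustrative and not part of the change):

```python
import torch
from transformers import AutoModelForCausalLM

import adapters

# Load Mistral with the flash-attention-2 implementation so that the
# flash attention forward patched in this commit is the one that runs.
model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-v0.1",               # illustrative checkpoint
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

# adapters.init() wraps the model so the forward methods defined in
# src/adapters/models/mistral/modeling_mistral.py take effect.
adapters.init(model)
model.add_adapter("demo_lora", config="lora")  # illustrative adapter name/config
model.set_active_adapters("demo_lora")
```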
divyanshuaggarwal authored Nov 23, 2024
1 parent 0e18a53 commit bf684ad
Showing 1 changed file with 6 additions and 15 deletions.
21 changes: 6 additions & 15 deletions src/adapters/models/mistral/modeling_mistral.py
@@ -45,7 +45,7 @@


if is_flash_attn_2_available():
-    from transformers.models.mistral.modeling_mistral import _flash_supports_window_size
+    from transformers.modeling_flash_attention_utils import _flash_attention_forward


logger = logging.get_logger(__name__)
@@ -173,18 +173,6 @@ def forward(
        cos, sin = self.rotary_emb(value_states, position_ids)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

-        use_sliding_windows = (
-            _flash_supports_window_size
-            and getattr(self.config, "sliding_window", None) is not None
-            and kv_seq_len > self.config.sliding_window
-        )
-
-        if not _flash_supports_window_size:
-            logger.warning_once(
-                "The current flash attention version does not support sliding window attention, for a more memory"
-                " efficient implementation make sure to upgrade flash-attn library."
-            )
-
        if past_key_value is not None:
            # Activate slicing cache only if the config has a value `sliding_windows` attribute
            cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0
@@ -257,14 +245,17 @@ def forward(
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)

-        attn_output = self._flash_attention_forward(
+        attn_output = _flash_attention_forward(
            query_states,
            key_states,
            value_states,
            attention_mask,
            q_len,
+            position_ids=position_ids,
            dropout=dropout_rate,
-            use_sliding_windows=use_sliding_windows,
+            sliding_window=getattr(self.config, "sliding_window", None),
+            use_top_left_mask=self._flash_attn_uses_top_left_mask,
+            is_causal=self.is_causal,
        )

        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
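
A hedged note on the intent of the change: the block deleted above gated sliding-window use locally, while the shared `_flash_attention_forward` helper is now given `sliding_window` directly and is expected to make that decision internally. Roughly, the removed check amounted to the following (a paraphrase for reference, not code in either library):

```python
# Paraphrase of the gating logic removed in this commit: sliding-window
# kwargs are only worth passing to flash-attn when a window is configured
# and the key/value sequence is longer than it.
def should_use_sliding_window(config, kv_seq_len: int, flash_supports_window: bool) -> bool:
    window = getattr(config, "sliding_window", None)
    return flash_supports_window and window is not None and kv_seq_len > window
```

Dropping the local check also removes the need for `_flash_supports_window_size`, which the new import path no longer provides.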
