From ed5a4d99b6497f8e436c8c8d7214c6dab1cfa49c Mon Sep 17 00:00:00 2001
From: Karol Damaszke
Date: Wed, 11 Dec 2024 14:42:54 +0200
Subject: [PATCH] Use FusedSDPA for MllamaVisionSdpaAttention

---
 vllm/model_executor/models/mllama.py | 11 ++++++-----
 1 file changed, 6 insertions(+), 5 deletions(-)

diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py
index 57c6bbc7c494d..58fda76e9f18d 100644
--- a/vllm/model_executor/models/mllama.py
+++ b/vllm/model_executor/models/mllama.py
@@ -414,11 +414,12 @@ def forward(
                    self.head_dim).transpose(1, 2)
 
         # TODO: remove padding in image encoder
-        attn_output = F.scaled_dot_product_attention(q,
-                                                     k,
-                                                     v,
-                                                     attn_mask=attention_mask,
-                                                     dropout_p=0.0)
+        if current_platform.is_hpu():
+            from habana_frameworks.torch.hpex.kernels import FusedSDPA
+            attn_output = FusedSDPA.apply(q, k, v, attention_mask, 0.0)
+        else:
+            attn_output = F.scaled_dot_product_attention(
+                q, k, v, attn_mask=attention_mask, dropout_p=0.0)
 
         attn_output = attn_output.transpose(1, 2).contiguous()
         attn_output = attn_output.reshape(attn_output.shape[0],
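
For reference, a minimal standalone sketch of the dispatch this diff introduces, not part of the patch itself. It assumes an HPU build where habana_frameworks is importable and that current_platform is the helper exposed by vllm.platforms (as used elsewhere in mllama.py); FusedSDPA.apply is assumed to mirror the positional signature of torch.nn.functional.scaled_dot_product_attention (query, key, value, attn_mask, dropout_p). On non-HPU platforms it falls back to the stock PyTorch path, matching the pre-patch behavior.

    # Sketch only: platform-dependent SDPA dispatch as in the diff above.
    # q, k, v are expected to be shaped (batch, num_heads, seq_len, head_dim).
    from typing import Optional

    import torch
    import torch.nn.functional as F

    from vllm.platforms import current_platform


    def vision_sdpa(q: torch.Tensor,
                    k: torch.Tensor,
                    v: torch.Tensor,
                    attention_mask: Optional[torch.Tensor] = None
                    ) -> torch.Tensor:
        if current_platform.is_hpu():
            # Habana's fused SDPA kernel; imported lazily so non-HPU builds
            # never need habana_frameworks installed.
            from habana_frameworks.torch.hpex.kernels import FusedSDPA
            return FusedSDPA.apply(q, k, v, attention_mask, 0.0)
        # Stock PyTorch path, identical to the code the patch removes.
        return F.scaled_dot_product_attention(q,
                                              k,
                                              v,
                                              attn_mask=attention_mask,
                                              dropout_p=0.0)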