From 865311b5193bdbfe56c17b7074804d9a59678768 Mon Sep 17 00:00:00 2001
From: Colman Glagovich
Date: Thu, 7 Nov 2024 07:07:45 -0800
Subject: [PATCH] #14848: Fix lt exclusion logic. Only apply full text mask if
 in prefill mode.

---
 models/demos/llama3/lt                                      | 2 +-
 .../test_llama_cross_attention_transformer_text.py          | 8 +-------
 .../llama3/tests/multimodal/test_llama_cross_block.py       | 8 +-------
 models/demos/llama3/tt/multimodal/llama_cross_block.py      | 4 ++--
 models/demos/llama3/tt/multimodal/llama_vision_model.py     | 8 +-------
 5 files changed, 6 insertions(+), 24 deletions(-)

diff --git a/models/demos/llama3/lt b/models/demos/llama3/lt
index 29d21683dc9..c69b4113a84 100755
--- a/models/demos/llama3/lt
+++ b/models/demos/llama3/lt
@@ -195,7 +195,7 @@ def main(stdscr):
         for c in commands
         for m in models
         for d in devices
-        if not (m in ["11b", "11b-b"] and d == "n150") or not (m == "70b" and d in ["n150", "n300"])
+        if not ((m in ["11b", "11b-b"] and d == "n150") or (m == "70b" and d in ["n150", "n300"]))
     ]
 
     # Create output entries
diff --git a/models/demos/llama3/tests/multimodal/test_llama_cross_attention_transformer_text.py b/models/demos/llama3/tests/multimodal/test_llama_cross_attention_transformer_text.py
index e9d0386c334..26f30bdef91 100644
--- a/models/demos/llama3/tests/multimodal/test_llama_cross_attention_transformer_text.py
+++ b/models/demos/llama3/tests/multimodal/test_llama_cross_attention_transformer_text.py
@@ -279,13 +279,7 @@ def test_llama_cross_attention_transformer_text_inference(
         mesh_mapper=ttnn.ShardTensorToMesh(mesh_device, dim=-1),
     )
     if mode == "decode":
-        tt_full_text_mask_expand_11SD = ttnn.reshape(
-            tt_full_text_mask_expand_11SD,
-            shape=ttnn.Shape(
-                [batch, 1, seq_len, head_dim],
-                [batch, 1, 32, head_dim],
-            ),
-        )
+        tt_full_text_mask_expand_11SD = None
 
     tt_out = tt_model(
         tt_h,
diff --git a/models/demos/llama3/tests/multimodal/test_llama_cross_block.py b/models/demos/llama3/tests/multimodal/test_llama_cross_block.py
index 7a2db392dfd..cdfa5ab19eb 100644
--- a/models/demos/llama3/tests/multimodal/test_llama_cross_block.py
+++ b/models/demos/llama3/tests/multimodal/test_llama_cross_block.py
@@ -194,13 +194,7 @@ def test_llama_cross_attention_transformer_block_inference(
         mesh_mapper=ttnn.ShardTensorToMesh(mesh_device, dim=-1),
     )
     if mode == "decode":
-        tt_full_text_mask_expand_11SD = ttnn.reshape(
-            tt_full_text_mask_expand_11SD,
-            shape=ttnn.Shape(
-                [batch, 1, seq_len, head_dim],
-                [batch, 1, 32, head_dim],
-            ),
-        )
+        tt_full_text_mask_expand_11SD = None
 
     pt_out = reference_model.forward(
         pt_x, xattn_mask=xattn_mask, full_text_row_masked_out_mask=full_text_mask, xattn_cache=pt_xattn_cache
diff --git a/models/demos/llama3/tt/multimodal/llama_cross_block.py b/models/demos/llama3/tt/multimodal/llama_cross_block.py
index d4f196c85f1..4f3d1cf394a 100644
--- a/models/demos/llama3/tt/multimodal/llama_cross_block.py
+++ b/models/demos/llama3/tt/multimodal/llama_cross_block.py
@@ -137,12 +137,12 @@ def forward(
             full_text_row_masked_out_mask_1NSH=full_text_row_masked_out_mask_1NSH,
             mode=mode,
         )
-
         attn_out = ttnn.mul(attn_out, ttnn.tanh(self.gate_attn))
         res = ttnn.add(x_11SH, attn_out)
 
         mlp_out = self.feed_forward(self.ffn_norm(res, mode=mode), mode=mode)
-        mlp_out = ttnn.mul(mlp_out, full_text_row_masked_out_mask_11SD)
+        if mode == "prefill":
+            mlp_out = ttnn.mul(mlp_out, full_text_row_masked_out_mask_11SD)
         mlp_out = ttnn.mul(mlp_out, ttnn.tanh(self.gate_ffwd))
         out = ttnn.add(res, mlp_out)
         return out
diff --git a/models/demos/llama3/tt/multimodal/llama_vision_model.py b/models/demos/llama3/tt/multimodal/llama_vision_model.py
index a2b602c78c7..47bff66e6e8 100644
--- a/models/demos/llama3/tt/multimodal/llama_vision_model.py
+++ b/models/demos/llama3/tt/multimodal/llama_vision_model.py
@@ -414,13 +414,7 @@ def forward(
             ),
         )
 
-        tt_full_text_mask_expand_11SD = ttnn.reshape(
-            tt_full_text_mask_expand_11SD,
-            shape=ttnn.Shape(
-                [batch, 1, seq_len, self.configuration.head_dim],
-                [batch, 1, 32, self.configuration.head_dim],
-            ),
-        )
+        tt_full_text_mask_expand_11SD = None
 
         logits = self.text_model.forward(
             tt_h,
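
Note on the lt change: it is a straight De Morgan fix. The old predicate `not A or not B` is true for every
model/device pair, because no model can satisfy both exclusion clauses at once, so nothing was ever excluded;
the parenthesized form `not (A or B)` excludes 11b/11b-b on n150 and 70b on n150/n300 as intended. Below is a
minimal sketch of the corrected predicate; the model/device lists and the `is_supported` helper are illustrative
assumptions, not lt's actual command matrix.

# Sketch only: illustrative model/device values, hypothetical helper.
from itertools import product

models = ["1b", "3b", "8b", "11b", "11b-b", "70b"]  # assumed example names
devices = ["n150", "n300", "t3k"]                   # assumed example names

def is_supported(m, d):
    # Old form `not A or not B` was always true; the parenthesized form
    # excludes 11b/11b-b on n150 and 70b on n150/n300.
    return not ((m in ["11b", "11b-b"] and d == "n150") or (m == "70b" and d in ["n150", "n300"]))

combos = [(m, d) for m, d in product(models, devices) if is_supported(m, d)]
assert ("11b", "n150") not in combos
assert ("70b", "n300") not in combos
assert ("11b-b", "n300") in combos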
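
Note on the mask change: `full_text_row_masked_out_mask_11SD` zeroes the cross-attention MLP output for text
rows with no vision tokens to attend to, which only matters during prefill; in decode the mask row for the
token being generated is presumably all ones, so the block skips the multiply and the call sites pass `None`.
The following is a rough PyTorch analogue of the block's MLP path under that assumption, with placeholder
callables and shapes, not the ttnn implementation.

# Rough PyTorch analogue of the gated cross-attention MLP path above.
# `feed_forward`, `ffn_norm`, and the tensor shapes are placeholders.
import torch

def mlp_path(res, feed_forward, ffn_norm, gate_ffwd, full_text_row_masked_out_mask_11SD, mode):
    mlp_out = feed_forward(ffn_norm(res))
    if mode == "prefill":
        # Zero text rows with no associated vision tokens; in decode the
        # caller passes None and the multiply is skipped entirely.
        mlp_out = mlp_out * full_text_row_masked_out_mask_11SD
    mlp_out = mlp_out * torch.tanh(gate_ffwd)
    return res + mlp_out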