#14848: Fix lt exclusion logic. Only apply full text mask if in decode mode.
cglagovichTT committed Nov 7, 2024
1 parent 74c4dea commit 865311b
Showing 5 changed files with 6 additions and 24 deletions.
2 changes: 1 addition & 1 deletion models/demos/llama3/lt
@@ -195,7 +195,7 @@ def main(stdscr):
         for c in commands
         for m in models
         for d in devices
-        if not (m in ["11b", "11b-b"] and d == "n150") or not (m == "70b" and d in ["n150", "n300"])
+        if not ((m in ["11b", "11b-b"] and d == "n150") or (m == "70b" and d in ["n150", "n300"]))
     ]

     # Create output entries
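The lt hunk above is a De Morgan-style correction: the old condition "not A or not B" is true unless both exclusion rules match at the same time, so unsupported model/device pairs slipped through the filter, while the new "not (A or B)" rejects a pair as soon as either rule matches. A minimal sketch (plain Python, not repository code; the model/device strings are taken from the diff):

def keep_old(m, d):
    # Old predicate: True unless BOTH exclusion rules match simultaneously,
    # which never happens here, so nothing was actually excluded.
    return not (m in ["11b", "11b-b"] and d == "n150") or not (m == "70b" and d in ["n150", "n300"])

def keep_new(m, d):
    # New predicate: False as soon as EITHER exclusion rule matches.
    return not ((m in ["11b", "11b-b"] and d == "n150") or (m == "70b" and d in ["n150", "n300"]))

print(keep_old("70b", "n150"))  # True  -> unsupported pair was kept before the fix
print(keep_new("70b", "n150"))  # False -> now correctly excluded
print(keep_new("11b", "n300"))  # True  -> supported pairs still pass the filter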
@@ -279,13 +279,7 @@ def test_llama_cross_attention_transformer_text_inference(
         mesh_mapper=ttnn.ShardTensorToMesh(mesh_device, dim=-1),
     )
     if mode == "decode":
-        tt_full_text_mask_expand_11SD = ttnn.reshape(
-            tt_full_text_mask_expand_11SD,
-            shape=ttnn.Shape(
-                [batch, 1, seq_len, head_dim],
-                [batch, 1, 32, head_dim],
-            ),
-        )
+        tt_full_text_mask_expand_11SD = None

     tt_out = tt_model(
         tt_h,
@@ -194,13 +194,7 @@ def test_llama_cross_attention_transformer_block_inference(
         mesh_mapper=ttnn.ShardTensorToMesh(mesh_device, dim=-1),
     )
     if mode == "decode":
-        tt_full_text_mask_expand_11SD = ttnn.reshape(
-            tt_full_text_mask_expand_11SD,
-            shape=ttnn.Shape(
-                [batch, 1, seq_len, head_dim],
-                [batch, 1, 32, head_dim],
-            ),
-        )
+        tt_full_text_mask_expand_11SD = None

     pt_out = reference_model.forward(
         pt_x, xattn_mask=xattn_mask, full_text_row_masked_out_mask=full_text_mask, xattn_cache=pt_xattn_cache
4 changes: 2 additions & 2 deletions models/demos/llama3/tt/multimodal/llama_cross_block.py
@@ -137,12 +137,12 @@ def forward(
         full_text_row_masked_out_mask_1NSH=full_text_row_masked_out_mask_1NSH,
         mode=mode,
     )

     attn_out = ttnn.mul(attn_out, ttnn.tanh(self.gate_attn))

     res = ttnn.add(x_11SH, attn_out)
     mlp_out = self.feed_forward(self.ffn_norm(res, mode=mode), mode=mode)
-    mlp_out = ttnn.mul(mlp_out, full_text_row_masked_out_mask_11SD)
+    if mode == "prefill":
+        mlp_out = ttnn.mul(mlp_out, full_text_row_masked_out_mask_11SD)
     mlp_out = ttnn.mul(mlp_out, ttnn.tanh(self.gate_ffwd))
     out = ttnn.add(res, mlp_out)
     return out
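The llama_cross_block.py hunk above is why the callers in this commit can pass None for tt_full_text_mask_expand_11SD in decode mode: the 11SD row mask is now multiplied into the MLP output only during prefill. A hedged sketch of that control flow in plain PyTorch (illustrative shapes and names, not the ttnn implementation):

import torch

def gated_mlp_out(mlp_out, full_text_row_masked_out_mask_11SD, gate_ffwd, mode):
    # Apply the full-text row mask only in prefill; decode callers may pass None.
    if mode == "prefill":
        mlp_out = mlp_out * full_text_row_masked_out_mask_11SD
    # The ffwd gate is applied in both modes, as in the diff above.
    return mlp_out * torch.tanh(gate_ffwd)

batch, seq_len, dim = 1, 32, 128
mlp_out = torch.randn(batch, 1, seq_len, dim)
mask_11SD = torch.randint(0, 2, (batch, 1, seq_len, 1)).float()
gate = torch.tensor([0.5])

prefill_out = gated_mlp_out(mlp_out, mask_11SD, gate, mode="prefill")
decode_out = gated_mlp_out(mlp_out[:, :, :1, :], None, gate, mode="decode")  # mask not needed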
8 changes: 1 addition & 7 deletions models/demos/llama3/tt/multimodal/llama_vision_model.py
@@ -414,13 +414,7 @@ def forward(
         ),
     )

-    tt_full_text_mask_expand_11SD = ttnn.reshape(
-        tt_full_text_mask_expand_11SD,
-        shape=ttnn.Shape(
-            [batch, 1, seq_len, self.configuration.head_dim],
-            [batch, 1, 32, self.configuration.head_dim],
-        ),
-    )
+    tt_full_text_mask_expand_11SD = None

     logits = self.text_model.forward(
         tt_h,
