From 5cde377ad53068a03a687e60742e8b2de83a6573 Mon Sep 17 00:00:00 2001
From: Colman Glagovich
Date: Wed, 13 Nov 2024 12:47:41 -0800
Subject: [PATCH] #15008: Move xattn cache generation to text prefill forward

(cherry picked from commit d0f78cbc0fda282c3d997bf10466028efac6aaa0)
---
 .../multimodal/test_llama_cross_attention.py  | 55 ++++++------
 ..._llama_cross_attention_transformer_text.py | 68 +++++++--------
 .../multimodal/test_llama_cross_block.py      | 56 ++++++------
 .../tt/multimodal/llama_cross_attention.py    | 87 +++++--------------
 .../llama_cross_attention_transformer_text.py |  2 +
 .../llama3/tt/multimodal/llama_cross_block.py |  5 +-
 .../tt/multimodal/llama_vision_model.py       | 12 ++-
 .../llama3/tt/multimodal/vision_generator.py  |  5 +-
 8 files changed, 120 insertions(+), 170 deletions(-)

diff --git a/models/demos/llama3/tests/multimodal/test_llama_cross_attention.py b/models/demos/llama3/tests/multimodal/test_llama_cross_attention.py
index e4830421311..cc34f091e17 100644
--- a/models/demos/llama3/tests/multimodal/test_llama_cross_attention.py
+++ b/models/demos/llama3/tests/multimodal/test_llama_cross_attention.py
@@ -93,7 +64,6 @@ def test_llama_cross_attention_inference(text_seq_len, batch, mesh_device, reset
     pt_xattn_cache_chunks = torch.chunk(pt_xattn_cache, 2, dim=0)
     pt_xattn_cache_chunks = [x.view(batch, n_heads, vision_seq_len, head_dim) for x in pt_xattn_cache]
 
-    # Iterate over batch
     # Preallocate K and V caches
     tt_xattn_cache = [
         ttnn.from_torch(
             torch.zeros(batch, n_heads, vision_seq_len, head_dim),
             device=mesh_device,
             layout=ttnn.TILE_LAYOUT,
             memory_config=ttnn.DRAM_MEMORY_CONFIG,
             dtype=ttnn.bfloat8_b,
             mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device),
         )
         for _ in range(2)
     ]
-    for b in range(batch):
-        tt_tensor_xattn_tokens = model_args.prepare_inputs_ttnn_prefill(
-            tt_xattn_tokens[b : b + 1],
-            force_replicated=True,
-        )
-        tt_xattn_cache = tt_model.compute_xattn_kv_cache(tt_tensor_xattn_tokens, tt_xattn_cache, user_id=b)
-        tt_xattn_cache_torch = [
-            ttnn.to_torch(x, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=1)).view(
-                batch,
-                n_heads,
-                vision_seq_len,
-                head_dim,
-            )
-            for x in tt_xattn_cache
-        ]
-
-        for pt, tt in zip(pt_xattn_cache_chunks, tt_xattn_cache_torch):
-            passing, pcc_message = comp_pcc(pt, tt, pcc_required)
-
-            logger.info(comp_allclose(pt, tt))
-            logger.info(f"PCC: {pcc_message}")
-            if passing:
-                logger.info(f"compute_xattn_kv_cache Passed!")
-            else:
-                logger.warning(f"compute_xattn_kv_cache Failed!")
-                all_tests_pass = False
-
-    assert all_tests_pass, f"PCC value is lower than {pcc_required} for some of the outputs. Check Warnings!"
     """
     Test forward, prefill and decode!
@@ -179,6 +150,10 @@ def test_llama_cross_attention_inference(text_seq_len, batch, mesh_device, reset
         if mode == "prefill":
             outputs = []
             for b in range(batch):
+                tt_tensor_xattn_tokens = model_args.prepare_inputs_ttnn_prefill(
+                    tt_xattn_tokens[b : b + 1],
+                    force_replicated=True,
+                )
                 tt_tensor_x = model_args.prepare_inputs_ttnn_prefill(
                     tt_x[b : b + 1],
                     force_replicated=True,
                 )
@@ -206,6 +181,7 @@ def test_llama_cross_attention_inference(text_seq_len, batch, mesh_device, reset
                     xattn_cache=tt_xattn_cache,
                     mode=mode,
                     user_id=b,
+                    vision_tokens=tt_tensor_xattn_tokens,
                 )
 
                 tt_output_torch = ttnn.to_torch(tt_out, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=-1))
@@ -271,4 +247,25 @@ def test_llama_cross_attention_inference(text_seq_len, batch, mesh_device, reset
             logger.info(f"PCC: {pcc_message}")
             all_tests_pass = all_tests_pass and passing
 
+        if mode == "prefill":
+            tt_xattn_cache_torch = [
+                ttnn.to_torch(x, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=1)).view(
+                    batch,
+                    n_heads,
+                    vision_seq_len,
+                    head_dim,
+                )
+                for x in tt_xattn_cache
+            ]
+            for pt, tt in zip(pt_xattn_cache_chunks, tt_xattn_cache_torch):
+                passing, pcc_message = comp_pcc(pt, tt, pcc_required)
+
+                logger.info(comp_allclose(pt, tt))
+                logger.info(f"PCC: {pcc_message}")
+                if passing:
+                    logger.info(f"compute_xattn_kv_cache Passed!")
+                else:
+                    logger.warning(f"compute_xattn_kv_cache Failed!")
+                    all_tests_pass = False
+
     assert all_tests_pass, f"PCC value is lower than {pcc_required} for some of the outputs. Check Warnings!"
diff --git a/models/demos/llama3/tests/multimodal/test_llama_cross_attention_transformer_text.py b/models/demos/llama3/tests/multimodal/test_llama_cross_attention_transformer_text.py
index 211c990dc3a..9a615672460 100644
--- a/models/demos/llama3/tests/multimodal/test_llama_cross_attention_transformer_text.py
+++ b/models/demos/llama3/tests/multimodal/test_llama_cross_attention_transformer_text.py
@@ -109,46 +109,11 @@ def test_llama_cross_attention_transformer_text_inference(
     # unstack k/v
     pt_xattn_cache_chunks = [torch.chunk(x, 2, dim=1) for x in pt_xattn_cache_chunks]
     pt_xattn_cache_chunks = [x for xx in pt_xattn_cache_chunks for x in xx]
-    # slice out replicated k/v heads
     pt_xattn_cache_chunks = [x.view(batch, n_heads, vision_seq_len, head_dim) for x in pt_xattn_cache_chunks]
 
     # Iterate over batch
     # Preallocate K and V caches
     tt_xattn_cache = tt_model.setup_cache(max_batch_size=batch)
-    for b in range(batch):
-        tt_tensor_vision_tokens = model_args.prepare_inputs_ttnn_prefill(
-            tt_vision_tokens[b : b + 1],
-            force_replicated=True,
-        )
-
-        tt_xattn_cache = [
-            layer.compute_xattn_kv_cache(tt_tensor_vision_tokens, tt_xattn_cache[layer_num], user_id=b)
-            for layer_num, layer in enumerate(tt_model.cross_attention_layers)
-        ]
-        tt_xattn_cache_torch = [
-            ttnn.to_torch(x, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=1)).view(
-                batch,
-                n_heads,
-                vision_seq_len,
-                head_dim,
-            )
-            for kv_cache in tt_xattn_cache
-            for x in kv_cache
-        ]
-
-        for pt, tt in zip(pt_xattn_cache_chunks, tt_xattn_cache_torch):
-            passing, pcc_message = comp_pcc(pt, tt, prefill_pcc_required)
-
-            logger.info(comp_allclose(pt, tt))
-            logger.info(f"PCC: {pcc_message}")
-
-            if passing:
-                logger.info(f"compute_xattn_kv_cache Passed!")
-            else:
-                logger.warning(f"compute_xattn_kv_cache Failed!")
-                all_tests_pass = False
-
-    assert all_tests_pass
 
     # Test forward pass of the model
     n_iter = 10
@@ -214,6 +179,10 @@ def test_llama_cross_attention_transformer_text_inference(
         if mode == "prefill":
             outputs = []
             for b in range(batch):
+                tt_tensor_vision_tokens = model_args.prepare_inputs_ttnn_prefill(
+                    tt_vision_tokens[b : b + 1],
+                    force_replicated=True,
+                )
                 tt_h = model_args.prepare_inputs_ttnn_prefill(
                     h[b : b + 1],
                 )
@@ -267,6 +236,7 @@ def test_llama_cross_attention_transformer_text_inference(
                     user_id=b,
                     mode=mode,
                     text_only_inference=TEXT_ONLY,
+                    vision_tokens=tt_tensor_vision_tokens,
                 )
 
                 tt_output_torch = ttnn.to_torch(tt_out, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=0))
@@ -357,5 +327,31 @@ def test_llama_cross_attention_transformer_text_inference(
             passing, pcc_message = comp_pcc(logits, tt_out, pcc_required)
             logger.info(comp_allclose(logits, tt_out))
             logger.info(f"PCC: {pcc_message}")
-            prev_pos = cur_pos
             assert passing, f"PCC value is lower than {pcc_required} for some of the outputs. Check Warnings!"
+            prev_pos = cur_pos
+
+        if mode == "prefill":
+            tt_xattn_cache_torch = [
+                ttnn.to_torch(x, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=1)).view(
+                    batch,
+                    n_heads,
+                    vision_seq_len,
+                    head_dim,
+                )
+                for kv_cache in tt_xattn_cache
+                for x in kv_cache
+            ]
+
+            for pt, tt in zip(pt_xattn_cache_chunks, tt_xattn_cache_torch):
+                passing, pcc_message = comp_pcc(pt, tt, prefill_pcc_required)
+
+                logger.info(comp_allclose(pt, tt))
+                logger.info(f"PCC: {pcc_message}")
+
+                if passing:
+                    logger.info(f"compute_xattn_kv_cache Passed!")
+                else:
+                    logger.warning(f"compute_xattn_kv_cache Failed!")
+                    all_tests_pass = False
+
+    assert all_tests_pass, f"PCC value is lower than {pcc_required} for some of the outputs. Check Warnings!"
diff --git a/models/demos/llama3/tests/multimodal/test_llama_cross_block.py b/models/demos/llama3/tests/multimodal/test_llama_cross_block.py
index d977d73e922..1b0013c78ee 100644
--- a/models/demos/llama3/tests/multimodal/test_llama_cross_block.py
+++ b/models/demos/llama3/tests/multimodal/test_llama_cross_block.py
@@ -87,7 +64,6 @@ def test_llama_cross_attention_transformer_block_inference(
     pt_xattn_cache_chunks = torch.chunk(pt_xattn_cache, 2, dim=0)
     pt_xattn_cache_chunks = [x.view(batch, n_heads, vision_seq_len, head_dim) for x in pt_xattn_cache]
 
-    # Iterate over batch
     # Preallocate K and V caches
     tt_xattn_cache = [
         ttnn.from_torch(
             torch.zeros(batch, n_heads, vision_seq_len, head_dim),
             device=mesh_device,
             layout=ttnn.TILE_LAYOUT,
             memory_config=ttnn.DRAM_MEMORY_CONFIG,
             dtype=ttnn.bfloat8_b,
             mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device),
         )
         for _ in range(2)
     ]
-    for b in range(batch):
-        tt_tensor_xattn_tokens = model_args.prepare_inputs_ttnn_prefill(
-            tt_xattn_tokens[b : b + 1],
-            force_replicated=True,
-        )
-        tt_xattn_cache = tt_model.compute_xattn_kv_cache(tt_tensor_xattn_tokens, tt_xattn_cache, user_id=b)
-        tt_xattn_cache_torch = [
-            ttnn.to_torch(x, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=1)).view(
-                batch,
-                n_heads,
-                vision_seq_len,
-                head_dim,
-            )
-            for x in tt_xattn_cache
-        ]
-
-        for pt, tt in zip(pt_xattn_cache_chunks, tt_xattn_cache_torch):
-            passing, pcc_message = comp_pcc(pt, tt, pcc_required)
-
-            logger.info(comp_allclose(pt, tt))
-            logger.info(f"PCC: {pcc_message}")
-            if passing:
-                logger.info(f"compute_xattn_kv_cache Passed!")
-            else:
-                logger.warning(f"compute_xattn_kv_cache Failed!")
-                all_tests_pass = False
-
-    assert all_tests_pass, f"PCC value is lower than {pcc_required} for some of the outputs. Check Warnings!"
     """
     Test forward, prefill and decode!
@@ -176,6 +147,10 @@ def test_llama_cross_attention_transformer_block_inference(
         if mode == "prefill":
             outputs = []
             for b in range(batch):
+                tt_tensor_xattn_tokens = model_args.prepare_inputs_ttnn_prefill(
+                    tt_xattn_tokens[b : b + 1],
+                    force_replicated=True,
+                )
                 tt_tensor_x = model_args.prepare_inputs_ttnn_prefill(
                     tt_x[b : b + 1],
                 )
@@ -211,6 +186,7 @@ def test_llama_cross_attention_transformer_block_inference(
                     xattn_cache=tt_xattn_cache,
                     mode=mode,
                     user_id=b,
+                    vision_tokens=tt_tensor_xattn_tokens,
                 )
 
                 tt_output_torch = ttnn.to_torch(tt_out, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=-1))
@@ -274,4 +250,26 @@ def test_llama_cross_attention_transformer_block_inference(
             logger.info(f"PCC: {pcc_message}")
             all_tests_pass = all_tests_pass and passing
 
+        if mode == "prefill":
+            tt_xattn_cache_torch = [
+                ttnn.to_torch(x, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=1)).view(
+                    batch,
+                    n_heads,
+                    vision_seq_len,
+                    head_dim,
+                )
+                for x in tt_xattn_cache
+            ]
+
+            for pt, tt in zip(pt_xattn_cache_chunks, tt_xattn_cache_torch):
+                passing, pcc_message = comp_pcc(pt, tt, pcc_required)
+
+                logger.info(comp_allclose(pt, tt))
+                logger.info(f"PCC: {pcc_message}")
+                if passing:
+                    logger.info(f"compute_xattn_kv_cache Passed!")
+                else:
+                    logger.warning(f"compute_xattn_kv_cache Failed!")
+                    all_tests_pass = False
+
     assert all_tests_pass, f"PCC value is lower than {pcc_required} for some of the outputs. Check Warnings!"
diff --git a/models/demos/llama3/tt/multimodal/llama_cross_attention.py b/models/demos/llama3/tt/multimodal/llama_cross_attention.py
index 87730dae903..63f87fbeb73 100644
--- a/models/demos/llama3/tt/multimodal/llama_cross_attention.py
+++ b/models/demos/llama3/tt/multimodal/llama_cross_attention.py
@@ -131,7 +131,12 @@ def __init__(
             eps=self.norm_eps,
         )
 
-    def compute_xattn_kv_cache(self, xattn_tokens, xattn_cache, user_id):
+    def compute_xattn_kv_cache(self, xattn_tokens, user_id, xattn_cache):
+        """
+        Uses xattn_tokens to compute K, V. Should be run inside of forward_prefill.
+        Updates xattn_cache with K, V (TODO: support page table for KV cache)
+        Returns contiguous K, V of this user in DRAM
+        """
         # Always runs with batch=1
         B, seqlen_y = xattn_tokens.shape[1], xattn_tokens.shape[2]
         assert B == 1, "Batch size must be 1"
@@ -179,15 +184,6 @@ def compute_xattn_kv_cache(self, xattn_tokens, xattn_cache, user_id):
             num_kv_heads=self.n_local_kv_heads // 2,
             transpose_k_heads=False,
         )
-        # def create_heads(x):
-        #     x = ttnn.to_layout(x, layout=ttnn.ROW_MAJOR_LAYOUT)
-        #     x = ttnn.reshape(x, [B, seqlen_y, self.n_local_kv_heads, self.head_dim])
-        #     x = ttnn.transpose(x, 1, 2)
-        #     x = ttnn.to_layout(x, layout=ttnn.TILE_LAYOUT)
-        #     return x
-
-        # xk = create_heads(xk)
-        # xv = create_heads(xv)
 
         xk = self.k_norm(xk, mode="decode")
 
@@ -204,26 +200,7 @@ def compute_xattn_kv_cache(self, xattn_tokens, xattn_cache, user_id):
         ttnn.fill_cache(k_cache, k_fill, user_id)
         ttnn.fill_cache(v_cache, v_fill, user_id)
 
-        return xattn_cache
-
-        ### Below is how I would like to implement TMs, but it results in poor PCC
-        xk = ttnn.to_layout(xk, layout=ttnn.ROW_MAJOR_LAYOUT)
-        xv = ttnn.to_layout(xv, layout=ttnn.ROW_MAJOR_LAYOUT)
-
-        xk = xk.reshape(bsz, seqlen_y, self.n_local_kv_heads, self.head_dim)
-        xv = xv.reshape(bsz, seqlen_y, self.n_local_kv_heads, self.head_dim)
-
-        xk = ttnn.transpose(xk, 1, 2)
-        xv = ttnn.transpose(xv, 1, 2)
-
-        xk = ttnn.to_layout(xk, layout=ttnn.TILE_LAYOUT)
-        xv = ttnn.to_layout(xv, layout=ttnn.TILE_LAYOUT)
-
-        # PREFERRED METHOD
-        # xk = xk.reshape(bsz, seqlen_y, self.n_local_kv_heads, self.head_dim)
-        # xv = xv.reshape(bsz, seqlen_y, self.n_local_kv_heads, self.head_dim)
-        # xk, xv = [ttnn.transpose(tensor, 1, 2) for tensor in (xk, xv)]  # HANG!
-        return [xk, xv]
+        return xk, xv
 
     def forward_decode(self, x_11SH, xattn_mask, full_text_row_masked_out_mask_1NSH, xattn_cache):
         batch = xattn_cache[0].shape[0]
@@ -239,20 +216,6 @@ def forward_decode(self, x_11SH, xattn_mask, full_text_row_masked_out_mask_1NSH,
             program_config=self.model_config["VISION_XATTN_Q_PROGCFG"](batch),
         )
 
-        # # Below is how we want to reshape. It results in poor PCC
-        # # 1, B, D -> B, 1, NH, DH -> B, NH, 1, DH
-        # xq = ttnn.to_layout(xq, layout=ttnn.ROW_MAJOR_LAYOUT)
-        # # Tell shape about padding
-        # xq = ttnn.reshape(
-        #     xq,
-        #     shape=ttnn.Shape(
-        #         [1, 1, batch, xq.shape[-1]],
-        #         [1, 1, xq.shape[-2], xq.shape[-1]],
-        #     ),
-        # )
-        # xq = ttnn.reshape(xq, (1, batch, self.n_local_heads, self.head_dim))
-        # xq = ttnn.to_layout(xq, layout=ttnn.TILE_LAYOUT)
-
         xq, _, _ = ttnn.experimental.nlp_create_qkv_heads(
             xq, xq, num_heads=self.n_local_heads, num_kv_heads=self.n_local_heads // 2, transpose_k_heads=False
         )
@@ -289,11 +252,6 @@ def forward_decode(self, x_11SH, xattn_mask, full_text_row_masked_out_mask_1NSH,
         # WARNING: this broadcast is also broken, must broadcast on host
         output = ttnn.mul(output, full_text_row_masked_out_mask_1NSH)
 
-        # This is how we should be reshaping
-        # output = ttnn.to_layout(output, layout=ttnn.ROW_MAJOR_LAYOUT)
-        # output = ttnn.reshape(output, (1, 1, batch, self.n_local_heads * self.head_dim))
-        # output = ttnn.to_layout(output, layout=ttnn.TILE_LAYOUT)
-
         output = ttnn.to_layout(output, layout=ttnn.ROW_MAJOR_LAYOUT)
         output = ttnn.transpose(output, 1, 2)  # 1, B, NH, DH -> 1, NH, B, DH
         output = ttnn.slice(output, (0, 0, 0, 0), (1, self.n_local_heads, batch, self.head_dim))
@@ -310,7 +268,7 @@ def forward_decode(self, x_11SH, xattn_mask, full_text_row_masked_out_mask_1NSH,
         )
 
         # All reduce
-        if self.is_multichip:  # TODO use_fused_all_gather_matmul
+        if self.is_multichip:
             dense_out_reduced = ttnn.reduce_scatter(
                 output,
                 scatter_dim=3,
@@ -322,11 +280,17 @@ def forward_decode(self, x_11SH, xattn_mask, full_text_row_masked_out_mask_1NSH,
         else:
             return output
 
-    def forward_prefill(self, x_11SH, xattn_mask, full_text_row_masked_out_mask_1NSH, xattn_cache, user_id):
+    def forward_prefill(
+        self, x_11SH, xattn_mask, full_text_row_masked_out_mask_1NSH, xattn_cache, user_id, vision_tokens
+    ):
         seq_len = x_11SH.shape[-2]
         # B, S, D
         assert seq_len % 32 == 0 and seq_len > 0, "Seqlen must be divisible by 32"
+
+        # Compute cross attention cache. Return contiguous caches
+        k_cache_user, v_cache_user = self.compute_xattn_kv_cache(vision_tokens, user_id, xattn_cache)
+        cache_seq_len = k_cache_user.shape[-2]
+
         if seq_len > 1024:
             x_11SH = ttnn.reshape(x_11SH, [1, seq_len // 1024, 1024, -1])
 
@@ -349,13 +313,6 @@ def forward_prefill(self, x_11SH, xattn_mask, full_text_row_masked_out_mask_1NSH
 
         xq = self.q_norm(xq, mode="prefill")
 
-        k_cache, v_cache = xattn_cache
-        cache_seq_len = k_cache.shape[-2]
-
-        k_cache_user = ttnn.slice(
-            k_cache, (user_id, 0, 0, 0), (user_id + 1, k_cache.shape[1], k_cache.shape[2], k_cache.shape[3])
-        )
-
         scores = ttnn.matmul(
             xq,
             ttnn.transpose(k_cache_user, -1, -2),
@@ -370,9 +327,6 @@ def forward_prefill(self, x_11SH, xattn_mask, full_text_row_masked_out_mask_1NSH
         scores = ttnn.add(scores, xattn_mask)
         scores = ttnn.softmax(scores, dim=-1, numeric_stable=True)
 
-        v_cache_user = ttnn.slice(
-            v_cache, (user_id, 0, 0, 0), (user_id + 1, v_cache.shape[1], v_cache.shape[2], v_cache.shape[3])
-        )
         output = ttnn.matmul(
             scores,
             v_cache_user,
@@ -413,10 +367,17 @@ def forward_prefill(self, x_11SH, xattn_mask, full_text_row_masked_out_mask_1NSH
         else:
             return output
 
-    def forward(self, x_11SH, xattn_mask, full_text_row_masked_out_mask_1NSH, xattn_cache, mode, user_id=0):
+    def forward(
+        self, x_11SH, xattn_mask, full_text_row_masked_out_mask_1NSH, xattn_cache, mode, user_id=0, vision_tokens=None
+    ):
         if mode == "prefill":
             return self.forward_prefill(
-                x_11SH, xattn_mask, full_text_row_masked_out_mask_1NSH, xattn_cache, user_id=user_id
+                x_11SH,
+                xattn_mask,
+                full_text_row_masked_out_mask_1NSH,
+                xattn_cache,
+                user_id=user_id,
+                vision_tokens=vision_tokens,
             )
         else:
             return self.forward_decode(x_11SH, xattn_mask, full_text_row_masked_out_mask_1NSH, xattn_cache)
diff --git a/models/demos/llama3/tt/multimodal/llama_cross_attention_transformer_text.py b/models/demos/llama3/tt/multimodal/llama_cross_attention_transformer_text.py
index 74392f3a732..0c50b2128b0 100644
--- a/models/demos/llama3/tt/multimodal/llama_cross_attention_transformer_text.py
+++ b/models/demos/llama3/tt/multimodal/llama_cross_attention_transformer_text.py
@@ -259,6 +259,7 @@ def forward(
         page_table=None,
         # get_last_token=-1,
         text_only_inference=False,
+        vision_tokens=None,
     ):
         for idx, (
             layer,
@@ -274,6 +275,7 @@ def forward(
                 full_text_row_masked_out_mask_11SD=full_text_row_masked_out_mask_11SD,
                 mode=mode,
                 user_id=user_id,
+                vision_tokens=vision_tokens,
             )
             h = layer(
                 h,
diff --git a/models/demos/llama3/tt/multimodal/llama_cross_block.py b/models/demos/llama3/tt/multimodal/llama_cross_block.py
index 7ef3754faeb..3ba172a7d39 100644
--- a/models/demos/llama3/tt/multimodal/llama_cross_block.py
+++ b/models/demos/llama3/tt/multimodal/llama_cross_block.py
@@ -114,9 +114,6 @@ def __init__(
             memory_config=ttnn.DRAM_MEMORY_CONFIG,
         )
 
-    def compute_xattn_kv_cache(self, xattn_tokens, xattn_cache, user_id):
-        return self.attention.compute_xattn_kv_cache(xattn_tokens, xattn_cache, user_id)
-
     def forward(
         self,
         x_11SH,
@@ -127,6 +124,7 @@ def forward(
         xattn_cache,
         mode,
         user_id=0,
+        vision_tokens=None,
     ):
         attn_out = self.attention(
             x_11SH=self.attention_norm(x_11SH, mode=mode),
@@ -135,6 +133,7 @@ def forward(
             full_text_row_masked_out_mask_1NSH=full_text_row_masked_out_mask_1NSH,
             mode=mode,
             user_id=user_id,
+            vision_tokens=vision_tokens,
         )
         attn_out = ttnn.mul(attn_out, ttnn.tanh(self.gate_attn))
diff --git a/models/demos/llama3/tt/multimodal/llama_vision_model.py b/models/demos/llama3/tt/multimodal/llama_vision_model.py
index f40bd7f593e..2611a43582c 100644
--- a/models/demos/llama3/tt/multimodal/llama_vision_model.py
+++ b/models/demos/llama3/tt/multimodal/llama_vision_model.py
@@ -175,8 +175,6 @@ def compute_vision_tokens_masks(
         batch_images: List[List[PIL_Image.Image]],
         batch_masks: List[List[List[int]]],
         total_len: int,
-        xattn_caches,
-        user_id,
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         skip_vision_encoder = False
 
@@ -244,10 +242,6 @@ def compute_vision_tokens_masks(
             mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device),
         )
 
-        xattn_caches = [
-            layer.compute_xattn_kv_cache(vision_tokens_tt, xattn_caches[layer_num], user_id=user_id)
-            for layer_num, layer in enumerate(self.text_model.cross_attention_layers)
-        ]
         padded_masks = _pad_masks(  # torch.Size([1, 512, 1, 4])
             batch_masks,
             num_chunks,
@@ -270,7 +264,7 @@ def compute_vision_tokens_masks(
             "constant",
             get_negative_inf_value(torch.float32),
         )
-        return (xattn_caches, cross_attention_masks, full_text_row_masked_out_mask)
+        return (vision_tokens_tt, cross_attention_masks, full_text_row_masked_out_mask)
 
     def validate_inputs(self, tokens, position_ids):
         batch, seq_len = tokens.shape[:2]
@@ -556,6 +550,7 @@ def forward(
         xattn_caches,  # list of ttnn tensors
         text_only_inference: bool = False,
         user_id=0,
+        vision_tokens=None,
     ) -> torch.Tensor:
         """
         This method takes torch tensors in, returns torch tensors.
@@ -595,6 +590,7 @@ def forward(
             user_id=user_id,
             mode=mode,
             text_only_inference=text_only_inference,
+            vision_tokens=vision_tokens,
         )
 
         tt_out = ttnn.to_layout(logits, ttnn.ROW_MAJOR_LAYOUT)
@@ -612,6 +608,7 @@ def ttnn_prefill_forward(
         rot_mats,
         transformation_mats,
         user_id,
+        vision_tokens,
     ):
         """
         This method runs prefill forward. It takes ttnn tensors in, returns ttnn tensors.
@@ -627,6 +624,7 @@ def ttnn_prefill_forward(
             transformation_mats=transformation_mats,
             user_id=user_id,
             mode="prefill",
+            vision_tokens=vision_tokens,
         )
         tt_out = ttnn.to_layout(logits, ttnn.ROW_MAJOR_LAYOUT)
         return tt_out
diff --git a/models/demos/llama3/tt/multimodal/vision_generator.py b/models/demos/llama3/tt/multimodal/vision_generator.py
index d0073a8b911..f8e3681216f 100644
--- a/models/demos/llama3/tt/multimodal/vision_generator.py
+++ b/models/demos/llama3/tt/multimodal/vision_generator.py
@@ -51,12 +51,10 @@ def prefill_forward_single_user(
         Returns (xattn_caches, cross_attention_masks, full_text_row_masked_out_mask, logits)
         """
         B = tokens.shape[0]
-        xattn_caches, cross_attention_masks, full_text_row_masked_out_mask = self.model.compute_vision_tokens_masks(
+        vision_tokens, cross_attention_masks, full_text_row_masked_out_mask = self.model.compute_vision_tokens_masks(
            batch_images=[vision_images],
            batch_masks=[vision_mask],
            total_len=total_len,
-            xattn_caches=xattn_caches,
-            user_id=user_id,
        )

        (
@@ -81,6 +79,7 @@ def prefill_forward_single_user(
            rot_mats,
            transformation_mats,
            user_id,
+            vision_tokens,
        )

        logits = self.model.process_output_prefill(tt_logits, B, prefill_len)