diff --git a/models/demos/t3000/llama2_70b/tests/test_llama_attention.py b/models/demos/t3000/llama2_70b/tests/test_llama_attention.py
index 9209e4c22c9..b86d877d4f2 100644
--- a/models/demos/t3000/llama2_70b/tests/test_llama_attention.py
+++ b/models/demos/t3000/llama2_70b/tests/test_llama_attention.py
@@ -194,7 +194,6 @@ def tt_llama_attention_prepare_inputs(llama_attention_model, x, start_pos, mode)
             mesh_mapper=ReplicateTensorToMesh(llama_attention_model.mesh_device),
             device=llama_attention_model.mesh_device,
         )
-        rot_mats = ttnn.to_device(rot_mats, llama_attention_model.mesh_device)
         rot_mats = ttnn.interleaved_to_sharded(rot_mats, llama_attention_model.model_config["ROT_MAT_MM_IN1_MEMCFG"])
diff --git a/models/demos/t3000/llama2_70b/tests/test_llama_decoder.py b/models/demos/t3000/llama2_70b/tests/test_llama_decoder.py
index 287426140b2..93376894540 100644
--- a/models/demos/t3000/llama2_70b/tests/test_llama_decoder.py
+++ b/models/demos/t3000/llama2_70b/tests/test_llama_decoder.py
@@ -188,7 +188,6 @@ def tt_llama_decoder_prepare_inputs(llama_decoder_model, x, start_pos, mode):
             mesh_mapper=ReplicateTensorToMesh(llama_decoder_model.mesh_device),
             device=llama_decoder_model.mesh_device,
         )
-        rot_mats = ttnn.to_device(rot_mats, llama_decoder_model.mesh_device)
         rot_mats = ttnn.interleaved_to_sharded(rot_mats, llama_decoder_model.model_config["ROT_MAT_MM_IN1_MEMCFG"])
diff --git a/models/demos/t3000/llama2_70b/tt/llama_generation.py b/models/demos/t3000/llama2_70b/tt/llama_generation.py
index 4e7fee77697..016000ab5f5 100644
--- a/models/demos/t3000/llama2_70b/tt/llama_generation.py
+++ b/models/demos/t3000/llama2_70b/tt/llama_generation.py
@@ -119,8 +119,21 @@ def forward(self, tokens: torch.Tensor, start_pos: int, page_table=None, kv_cach
 
     def capture_trace(self, tokens: torch.Tensor, start_pos: int, page_table=None, kv_cache=None):
         # Get inputs on device
-        tt_inp, tt_inp_emb, start_pos, rot_mat, cache_idxs_tt, tt_page_table = self.tt_model.prepare_device_inputs(
-            tokens, start_pos, mode="decode", page_table=page_table, return_tokens=True
+        (
+            tt_inp_emb,
+            start_pos,
+            rot_mat,
+            cache_idxs_tt,
+            tt_page_table,
+            tt_inp,
+            rot_mat_rm,
+        ) = self.tt_model.prepare_device_inputs(
+            tokens,
+            start_pos,
+            mode="decode",
+            page_table=page_table,
+            return_tokens=True,
+            return_rot_mat_rm=True,
         )
 
         # Compile model
@@ -140,6 +153,8 @@ def capture_trace(self, tokens: torch.Tensor, start_pos: int, page_table=None, k
         # Run TT model
         tt_inp_emb = self.tt_model.tt_embd(tt_inp)
         tt_inp_emb = ttnn.interleaved_to_sharded(tt_inp_emb, self.model_config["WORD_EMBEDDING_OUTPUT_MEMCFG"])
+        rot_mat = ttnn.to_layout(rot_mat_rm, ttnn.TILE_LAYOUT)
+        rot_mat = ttnn.interleaved_to_sharded(rot_mat, self.model_config["ROT_MAT_MM_IN1_MEMCFG"])
         tt_logits = self.tt_model(
             tt_inp_emb,
             rot_mat,
@@ -153,7 +168,7 @@ def capture_trace(self, tokens: torch.Tensor, start_pos: int, page_table=None, k
         ttnn.end_trace_capture(self.mesh_device, trace_id, cq_id=0)
         logger.info("Done Capturing Decode Trace")
 
-        return trace_id, tt_inp, rot_mat, cache_idxs_tt, tt_logits, tt_page_table
+        return trace_id, tt_inp, rot_mat_rm, cache_idxs_tt, tt_logits, tt_page_table
 
     def delete_trace(self, trace_id):
         ttnn.release_trace(self.mesh_device, trace_id)
diff --git a/models/demos/t3000/llama2_70b/tt/llama_model_optimized.py b/models/demos/t3000/llama2_70b/tt/llama_model_optimized.py
index cc6fcab7920..66c058bd503 100644
--- a/models/demos/t3000/llama2_70b/tt/llama_model_optimized.py
+++ b/models/demos/t3000/llama2_70b/tt/llama_model_optimized.py
@@ -269,7 +269,7 @@ def prepare_inputs(self, inp_ids, start_pos, valid_seq_len=None, mode="decode",
             rot_mats = ttnn.as_tensor(
                 rot_mat,
                 dtype=ttnn.bfloat16,
-                layout=ttnn.TILE_LAYOUT,
+                layout=ttnn.ROW_MAJOR_LAYOUT,
                 mesh_mapper=ReplicateTensorToMesh(self.mesh_device),
             )
@@ -299,6 +299,7 @@ def prepare_device_inputs(
         mode="decode",
         page_table=None,
         return_tokens=False,  # if true, return tokens for decode mode
+        return_rot_mat_rm=False,  # if true, also return rot_mat in row-major layout for decode
    ):
        tt_inp, start_pos, rot_mat, cache_idxs_tt, tt_page_table = self.prepare_inputs(
            tokens, start_pos, valid_seq_len=valid_seq_len, mode=mode, page_table=page_table
@@ -308,19 +309,21 @@ def prepare_device_inputs(
             tt_inp = ttnn.to_device(tt_inp, self.mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG)
             tt_inp_emb = self.tt_embd(tt_inp)
             tt_inp_emb = ttnn.interleaved_to_sharded(tt_inp_emb, self.model_config["WORD_EMBEDDING_OUTPUT_MEMCFG"])
-            rot_mat = ttnn.to_device(
-                rot_mat, self.mesh_device, memory_config=self.model_config["ROT_MAT_MM_IN1_MEMCFG"]
-            )
+            rot_mat_rm = ttnn.to_device(rot_mat, self.mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG)
+            rot_mat = ttnn.to_layout(rot_mat_rm, ttnn.TILE_LAYOUT)
+            rot_mat = ttnn.interleaved_to_sharded(rot_mat, self.model_config["ROT_MAT_MM_IN1_MEMCFG"])
             cache_idxs_tt = ttnn.to_device(cache_idxs_tt, self.mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG)
             if tt_page_table is not None:
                 tt_page_table = ttnn.to_device(tt_page_table, self.mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG)
         else:
             tt_inp_emb = tt_inp
 
-        return_out = []
-        if mode == "decode" and return_tokens:
-            return_out.append(tt_inp)
-        return_out.extend([tt_inp_emb, start_pos, rot_mat, cache_idxs_tt, tt_page_table])
+        return_out = [tt_inp_emb, start_pos, rot_mat, cache_idxs_tt, tt_page_table]
+        if mode == "decode":
+            if return_tokens:
+                return_out.append(tt_inp)
+            if return_rot_mat_rm:
+                return_out.append(rot_mat_rm)
         return tuple(return_out)
 
     def __call__(
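
For reference, the calling pattern introduced by this diff looks roughly like the sketch below: `prepare_device_inputs` now hands back the rotary matrix both tiled and sharded (used for the eager compile run) and as a row-major, DRAM-interleaved tensor (`rot_mat_rm`), and the trace-captured region re-derives the tiled, sharded copy from `rot_mat_rm`, presumably so later trace executions can pick up an updated rotary matrix through that row-major tensor. This is a minimal sketch, not an exact excerpt; it assumes a wrapper object with `tt_model`, `model_config`, and `mesh_device` set up as in `llama_generation.py` above.

```python
# Sketch (assumed wrapper attributes: self.tt_model, self.model_config, self.mesh_device).
(
    tt_inp_emb,
    start_pos,
    rot_mat,        # tiled + sharded, consumed by the eager "compile" run
    cache_idxs_tt,
    tt_page_table,
    tt_inp,         # returned because return_tokens=True
    rot_mat_rm,     # row-major, DRAM-interleaved copy kept alive for the trace
) = self.tt_model.prepare_device_inputs(
    tokens,
    start_pos,
    mode="decode",
    page_table=page_table,
    return_tokens=True,
    return_rot_mat_rm=True,
)

# Inside the captured trace, rebuild the tiled/sharded rot_mat from the row-major copy,
# mirroring what prepare_device_inputs does on the host path.
rot_mat = ttnn.to_layout(rot_mat_rm, ttnn.TILE_LAYOUT)
rot_mat = ttnn.interleaved_to_sharded(rot_mat, self.model_config["ROT_MAT_MM_IN1_MEMCFG"])
```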