#0: Push rotary embeddings as row-major for e2e speedup, 10 -> 14.3 t/s/u Llama3 in demo.

(cherry picked from commit e1251ad)
(cherry picked from commit 2bfdbf5)
cglagovichTT authored and skhorasganiTT committed Oct 9, 2024
1 parent 814e8cc commit d17fd5c
Showing 4 changed files with 29 additions and 13 deletions.
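
In outline, the change stops tilizing the rotary matrix on host: prepare_inputs now builds it in ROW_MAJOR_LAYOUT, and the tilize and shard steps run on device after the transfer. Below is a minimal sketch of that flow distilled from the hunks that follow; the helper name push_rot_mat_row_major is illustrative, and mesh_device, model_config, and the host-side rot_mat tensor are assumed to come from the surrounding model code.

import ttnn
from ttnn import ReplicateTensorToMesh  # import path assumed to match the repo's usage


def push_rot_mat_row_major(rot_mat, mesh_device, model_config):
    # Host side: keep the rotary matrix row-major (no host tilization) and
    # replicate it across the mesh devices.
    rot_mat_host = ttnn.as_tensor(
        rot_mat,
        dtype=ttnn.bfloat16,
        layout=ttnn.ROW_MAJOR_LAYOUT,
        mesh_mapper=ReplicateTensorToMesh(mesh_device),
    )

    # Device side: transfer the row-major tensor into DRAM, then tilize and
    # shard it there so the rotary matmul still receives its expected layout.
    rot_mat_rm = ttnn.to_device(rot_mat_host, mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG)
    rot_mat_tile = ttnn.to_layout(rot_mat_rm, ttnn.TILE_LAYOUT)
    rot_mat_sharded = ttnn.interleaved_to_sharded(rot_mat_tile, model_config["ROT_MAT_MM_IN1_MEMCFG"])

    # rot_mat_rm is kept around for trace capture; rot_mat_sharded feeds the model.
    return rot_mat_rm, rot_mat_sharded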
@@ -194,7 +194,6 @@ def tt_llama_attention_prepare_inputs(llama_attention_model, x, start_pos, mode)
             mesh_mapper=ReplicateTensorToMesh(llama_attention_model.mesh_device),
             device=llama_attention_model.mesh_device,
         )
-        rot_mats = ttnn.to_device(rot_mats, llama_attention_model.mesh_device)

         rot_mats = ttnn.interleaved_to_sharded(rot_mats, llama_attention_model.model_config["ROT_MAT_MM_IN1_MEMCFG"])
1 change: 0 additions & 1 deletion models/demos/t3000/llama2_70b/tests/test_llama_decoder.py
@@ -188,7 +188,6 @@ def tt_llama_decoder_prepare_inputs(llama_decoder_model, x, start_pos, mode):
             mesh_mapper=ReplicateTensorToMesh(llama_decoder_model.mesh_device),
             device=llama_decoder_model.mesh_device,
         )
-        rot_mats = ttnn.to_device(rot_mats, llama_decoder_model.mesh_device)

         rot_mats = ttnn.interleaved_to_sharded(rot_mats, llama_decoder_model.model_config["ROT_MAT_MM_IN1_MEMCFG"])
21 changes: 18 additions & 3 deletions models/demos/t3000/llama2_70b/tt/llama_generation.py
@@ -119,8 +119,21 @@ def forward(self, tokens: torch.Tensor, start_pos: int, page_table=None, kv_cache=None):
 
     def capture_trace(self, tokens: torch.Tensor, start_pos: int, page_table=None, kv_cache=None):
         # Get inputs on device
-        tt_inp, tt_inp_emb, start_pos, rot_mat, cache_idxs_tt, tt_page_table = self.tt_model.prepare_device_inputs(
-            tokens, start_pos, mode="decode", page_table=page_table, return_tokens=True
+        (
+            tt_inp_emb,
+            start_pos,
+            rot_mat,
+            cache_idxs_tt,
+            tt_page_table,
+            tt_inp,
+            rot_mat_rm,
+        ) = self.tt_model.prepare_device_inputs(
+            tokens,
+            start_pos,
+            mode="decode",
+            page_table=page_table,
+            return_tokens=True,
+            return_rot_mat_rm=True,
         )
 
         # Compile model

@@ -140,6 +153,8 @@ def capture_trace(self, tokens: torch.Tensor, start_pos: int, page_table=None, kv_cache=None):
         # Run TT model
         tt_inp_emb = self.tt_model.tt_embd(tt_inp)
         tt_inp_emb = ttnn.interleaved_to_sharded(tt_inp_emb, self.model_config["WORD_EMBEDDING_OUTPUT_MEMCFG"])
+        rot_mat = ttnn.to_layout(rot_mat_rm, ttnn.TILE_LAYOUT)
+        rot_mat = ttnn.interleaved_to_sharded(rot_mat, self.model_config["ROT_MAT_MM_IN1_MEMCFG"])
         tt_logits = self.tt_model(
             tt_inp_emb,
             rot_mat,

@@ -153,7 +168,7 @@ def capture_trace(self, tokens: torch.Tensor, start_pos: int, page_table=None, kv_cache=None):
         ttnn.end_trace_capture(self.mesh_device, trace_id, cq_id=0)
         logger.info("Done Capturing Decode Trace")
 
-        return trace_id, tt_inp, rot_mat, cache_idxs_tt, tt_logits, tt_page_table
+        return trace_id, tt_inp, rot_mat_rm, cache_idxs_tt, tt_logits, tt_page_table
 
     def delete_trace(self, trace_id):
         ttnn.release_trace(self.mesh_device, trace_id)
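
With this change capture_trace hands back rot_mat_rm, the row-major rotary tensor in DRAM, instead of the tiled, sharded rot_mat, because the to_layout and interleaved_to_sharded steps now sit inside the captured region and are replayed on every trace execution. A hedged caller sketch of the new return tuple follows; generator stands in for whatever instance exposes capture_trace, and how the caller refreshes rot_mat_rm between executions is outside this hunk.

# Names follow the return statement in the diff above; generator is a placeholder.
trace_id, tt_inp, rot_mat_rm, cache_idxs_tt, tt_logits, tt_page_table = generator.capture_trace(
    tokens, start_pos, page_table=page_table, kv_cache=kv_cache
)
# rot_mat_rm stays row-major in DRAM; each replay of the trace re-runs the
# captured to_layout(TILE_LAYOUT) + interleaved_to_sharded conversion on it.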
19 changes: 11 additions & 8 deletions models/demos/t3000/llama2_70b/tt/llama_model_optimized.py
@@ -269,7 +269,7 @@ def prepare_inputs(self, inp_ids, start_pos, valid_seq_len=None, mode="decode",
             rot_mats = ttnn.as_tensor(
                 rot_mat,
                 dtype=ttnn.bfloat16,
-                layout=ttnn.TILE_LAYOUT,
+                layout=ttnn.ROW_MAJOR_LAYOUT,
                 mesh_mapper=ReplicateTensorToMesh(self.mesh_device),
             )
 

@@ -299,6 +299,7 @@ def prepare_device_inputs(
         mode="decode",
         page_table=None,
         return_tokens=False,  # if true, return tokens for decode mode
+        return_rot_mat_rm=False,  # if true, also return rot_mat in row-major layout for decode
     ):
         tt_inp, start_pos, rot_mat, cache_idxs_tt, tt_page_table = self.prepare_inputs(
             tokens, start_pos, valid_seq_len=valid_seq_len, mode=mode, page_table=page_table

@@ -308,19 +309,21 @@ def prepare_device_inputs(
             tt_inp = ttnn.to_device(tt_inp, self.mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG)
             tt_inp_emb = self.tt_embd(tt_inp)
             tt_inp_emb = ttnn.interleaved_to_sharded(tt_inp_emb, self.model_config["WORD_EMBEDDING_OUTPUT_MEMCFG"])
-            rot_mat = ttnn.to_device(
-                rot_mat, self.mesh_device, memory_config=self.model_config["ROT_MAT_MM_IN1_MEMCFG"]
-            )
+            rot_mat_rm = ttnn.to_device(rot_mat, self.mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG)
+            rot_mat = ttnn.to_layout(rot_mat_rm, ttnn.TILE_LAYOUT)
+            rot_mat = ttnn.interleaved_to_sharded(rot_mat, self.model_config["ROT_MAT_MM_IN1_MEMCFG"])
             cache_idxs_tt = ttnn.to_device(cache_idxs_tt, self.mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG)
             if tt_page_table is not None:
                 tt_page_table = ttnn.to_device(tt_page_table, self.mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG)
         else:
             tt_inp_emb = tt_inp
 
-        return_out = []
-        if mode == "decode" and return_tokens:
-            return_out.append(tt_inp)
-        return_out.extend([tt_inp_emb, start_pos, rot_mat, cache_idxs_tt, tt_page_table])
+        return_out = [tt_inp_emb, start_pos, rot_mat, cache_idxs_tt, tt_page_table]
+        if mode == "decode":
+            if return_tokens:
+                return_out.append(tt_inp)
+            if return_rot_mat_rm:
+                return_out.append(rot_mat_rm)
         return tuple(return_out)
 
     def __call__(
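
Note that the decode-mode return order also changes: the base tuple (tt_inp_emb, start_pos, rot_mat, cache_idxs_tt, tt_page_table) comes first, then tt_inp if return_tokens is set, then rot_mat_rm if return_rot_mat_rm is set; previously tt_inp was returned first. A minimal usage sketch under that reading, where model is an assumed instance providing the method:

# Decode-mode call requesting both optional outputs; the unpacking order
# mirrors the return_out construction above.
(
    tt_inp_emb,
    start_pos,
    rot_mat,  # tiled and sharded, ready for the rotary matmul
    cache_idxs_tt,
    tt_page_table,
    tt_inp,  # appended because return_tokens=True
    rot_mat_rm,  # appended because return_rot_mat_rm=True
) = model.prepare_device_inputs(
    tokens,
    start_pos,
    mode="decode",
    page_table=page_table,
    return_tokens=True,
    return_rot_mat_rm=True,
)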
