Commit 814e8cc

#0: Add optional page table input to llama trace functions to enable tracing in vLLM

Signed-off-by: Salar Hosseini <[email protected]>
skhorasganiTT committed Oct 9, 2024
1 parent 27ba1f1 commit 814e8cc
Showing 7 changed files with 115 additions and 71 deletions.
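At a glance, the change threads an optional page_table argument through the llama trace entry points and factors the repeated host-to-device input staging into a new prepare_device_inputs helper. A condensed sketch of the affected method signatures as they appear in the hunks below (bodies elided; comments summarize the return shapes):

    # tt/llama_generation.py
    def capture_trace(self, tokens, start_pos, page_table=None, kv_cache=None):
        ...  # returns (trace_id, tt_inp, rot_mat, cache_idxs_tt, tt_logits, tt_page_table)

    def decode_forward_trace(self, tokens, start_pos, trace_id, tt_inp, rot_mat,
                             cache_idxs_tt, tt_logits, page_table=None, tt_page_table=None):
        ...  # replays the captured trace, refreshing its input buffers first

    # tt/llama_model_optimized.py
    def prepare_inputs(self, inp_ids, start_pos, valid_seq_len=None, mode="decode", page_table=None):
        ...  # now also returns the page table, converted to a ttnn tensor if given as torch.Tensor

    def prepare_device_inputs(self, tokens, start_pos, valid_seq_len=None, mode="decode",
                              page_table=None, return_tokens=False):
        ...  # returns (tt_inp_emb, start_pos, rot_mat, cache_idxs_tt, tt_page_table);
             # with return_tokens=True in decode mode, tt_inp is prepended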
2 changes: 1 addition & 1 deletion models/demos/t3000/llama2_70b/demo/demo.py
@@ -261,7 +261,7 @@ def run_decode(
          # capture trace
          if trace_mode:
              logger.info("Capturing trace")
-             trace_id, tt_inp_emb, rot_mat, cache_idxs_tt, tt_logits = model.capture_trace(
+             trace_id, tt_inp_emb, rot_mat, cache_idxs_tt, tt_logits, _ = model.capture_trace(
                  tokens[:, prev_pos:min_prompt_len], prev_pos
              )

10 changes: 1 addition & 9 deletions models/demos/t3000/llama2_70b/tests/test_llama_model.py
@@ -159,15 +159,7 @@ def run_test_LlamaModel_inference(
      if device_perf:
          signpost(DEVICE_PERF_START_SIGNPOST)  # start for device perf measurement
      # TT hardware execution -------------------------------------------------------------
-     tt_inp_emb, start_pos, rot_mat, cache_idxs = tt_model.prepare_inputs(tt_inp_ids, start_pos, mode=mode)
-
-     # Send to device
-     if mode == "decode":
-         tt_inp_emb = ttnn.to_device(tt_inp_emb, t3k_mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG)
-         tt_inp_emb = tt_model.tt_embd(tt_inp_emb)
-         tt_inp_emb = ttnn.interleaved_to_sharded(tt_inp_emb, model_config["WORD_EMBEDDING_OUTPUT_MEMCFG"])
-         rot_mat = ttnn.to_device(rot_mat, t3k_mesh_device, memory_config=model_config["ROT_MAT_MM_IN1_MEMCFG"])
-         cache_idxs = ttnn.to_device(cache_idxs, t3k_mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG)
+     tt_inp_emb, start_pos, rot_mat, cache_idxs, _ = tt_model.prepare_device_inputs(tt_inp_ids, start_pos, mode=mode)

      tt_out = tt_model(
          tt_inp_emb,
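The same substitution repeats in the test files below: the deleted per-call-site staging is what the new helper now performs internally for decode mode. Schematically (a sketch of the equivalence; prepare_device_inputs itself appears at the end of this diff in llama_model_optimized.py):

    # Before: every test staged decode inputs by hand
    tt_inp_emb, start_pos, rot_mat, cache_idxs = tt_model.prepare_inputs(inp_ids, start_pos, mode="decode")
    tt_inp_emb = ttnn.to_device(tt_inp_emb, mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG)
    tt_inp_emb = tt_model.tt_embd(tt_inp_emb)
    tt_inp_emb = ttnn.interleaved_to_sharded(tt_inp_emb, model_config["WORD_EMBEDDING_OUTPUT_MEMCFG"])
    rot_mat = ttnn.to_device(rot_mat, mesh_device, memory_config=model_config["ROT_MAT_MM_IN1_MEMCFG"])
    cache_idxs = ttnn.to_device(cache_idxs, mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG)

    # After: one call; the trailing _ is the new tt_page_table slot, unused in these tests
    tt_inp_emb, start_pos, rot_mat, cache_idxs, _ = tt_model.prepare_device_inputs(inp_ids, start_pos, mode="decode")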
8 changes: 6 additions & 2 deletions models/demos/t3000/llama2_70b/tests/test_llama_perf.py
@@ -154,7 +154,9 @@ def run_test_LlamaModel_end_to_end(
      if user_id == 0 or user_id == 25:
          profiler.start(f"processing_of_prefill_input_{user_id}")

-     tt_inp_emb, start_pos, rot_mat = tt_model.prepare_inputs(prefill_ids[user_id : user_id + 1], start_pos=0)
+     tt_inp_emb, start_pos, rot_mat, _, _ = tt_model.prepare_device_inputs(
+         prefill_ids[user_id : user_id + 1], start_pos=0, mode="prefill"
+     )
      if user_id == 0 or user_id == 25:
          profiler.end(f"processing_of_prefill_input_{user_id}")
          profiler.start(f"model_run_for_prefill_{user_id}")
@@ -202,7 +204,9 @@ def run_test_LlamaModel_end_to_end(
      if cur_pos == 0 or cur_pos == 35:  # Skip the first few iterations to warm up
          profiler.start(f"processing_of_decode_input_{cur_pos}")

-     tt_inp_emb, start_pos, rot_mat, cache_idxs = tt_model.prepare_inputs(decode_ids, start_pos)
+     tt_inp_emb, start_pos, rot_mat, cache_idxs, _ = tt_model.prepare_device_inputs(
+         decode_ids, start_pos, mode="decode"
+     )

      tt_inp_emb = ttnn.to_device(tt_inp_emb, mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG)
      tt_inp_emb = tt_model.tt_embd(tt_inp_emb)
@@ -127,13 +127,7 @@ def run_test_LlamaModel_end_to_end(

      ##### Prepare Inputs #####
      prev_pos = total_len - 1
-     tt_inp_emb, prev_pos, rot_mat, cache_idxs = tt_model.prepare_inputs(tokens, prev_pos)
-     tt_inp_emb = ttnn.to_device(tt_inp_emb, mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG)
-     tt_inp_emb = tt_model.tt_embd(tt_inp_emb)
-     tt_inp_emb = ttnn.interleaved_to_sharded(tt_inp_emb, tt_model.model_config["WORD_EMBEDDING_OUTPUT_MEMCFG"])
-
-     rot_mat = ttnn.to_device(rot_mat, mesh_device, memory_config=tt_model.model_config["ROT_MAT_MM_IN1_MEMCFG"])
-     cache_idxs = ttnn.to_device(cache_idxs, mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG)
+     tt_inp_emb, prev_pos, rot_mat, cache_idxs, _ = tt_model.prepare_device_inputs(tokens, prev_pos)

      ##### Compile Model #####
      logger.info("Compiling model")
9 changes: 3 additions & 6 deletions models/demos/t3000/llama2_70b/tests/test_llama_stress_test.py
@@ -98,12 +98,9 @@ def run_test_LlamaModel_stress_test(
      start_pos = 0
      prev_pos = start_pos
      for cur_pos in tqdm(range(start_pos + 1, total_len), desc="Decode to 2k Progress", leave=False, colour="green"):
-         tt_inp_emb, prev_pos, rot_mat, cache_idxs = tt_model.prepare_inputs(tokens[:, prev_pos:cur_pos], prev_pos)
-         tt_inp_emb = ttnn.to_device(tt_inp_emb, mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG)
-         tt_inp_emb = tt_model.tt_embd(tt_inp_emb)
-         tt_inp_emb = ttnn.interleaved_to_sharded(tt_inp_emb, model_config["WORD_EMBEDDING_OUTPUT_MEMCFG"])
-         rot_mat = ttnn.to_device(rot_mat, mesh_device, memory_config=model_config["ROT_MAT_MM_IN1_MEMCFG"])
-         cache_idxs = ttnn.to_device(cache_idxs, mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG)
+         tt_inp_emb, prev_pos, rot_mat, cache_idxs, _ = tt_model.prepare_device_inputs(
+             tokens[:, prev_pos:cur_pos], prev_pos
+         )

          tt_logits = tt_model(tt_inp_emb, rot_mat, prev_pos, cache_idxs=cache_idxs)

93 changes: 49 additions & 44 deletions models/demos/t3000/llama2_70b/tt/llama_generation.py
@@ -117,35 +117,58 @@ def forward(self, tokens: torch.Tensor, start_pos: int, page_table=None, kv_cach
              tokens, start_pos, page_table=page_table, kv_cache=kv_cache, prompt_lens=prompt_lens
          )

-     def capture_trace(self, tokens: torch.Tensor, start_pos: int):
-         tt_inp, start_pos, rot_mat, cache_idxs_tt = self.tt_model.prepare_inputs(tokens, start_pos, mode="decode")
+     def capture_trace(self, tokens: torch.Tensor, start_pos: int, page_table=None, kv_cache=None):
+         # Get inputs on device
+         tt_inp, tt_inp_emb, start_pos, rot_mat, cache_idxs_tt, tt_page_table = self.tt_model.prepare_device_inputs(
+             tokens, start_pos, mode="decode", page_table=page_table, return_tokens=True
+         )

          # Compile model
-         tt_inp = ttnn.to_device(tt_inp, self.mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG)
-         tt_inp_emb = self.tt_model.tt_embd(tt_inp)
-         tt_inp_emb = ttnn.interleaved_to_sharded(tt_inp_emb, self.model_config["WORD_EMBEDDING_OUTPUT_MEMCFG"])
-         rot_mat = ttnn.to_device(rot_mat, self.mesh_device, memory_config=self.model_config["ROT_MAT_MM_IN1_MEMCFG"])
-         cache_idxs_tt = ttnn.to_device(cache_idxs_tt, self.mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG)
-         tt_logits = self.tt_model(tt_inp_emb, rot_mat, start_pos, cache_idxs=cache_idxs_tt, mode="decode")
+         tt_logits = self.tt_model(
+             tt_inp_emb,
+             rot_mat,
+             start_pos,
+             cache_idxs=cache_idxs_tt,
+             page_table=tt_page_table,
+             kv_cache=kv_cache,
+             mode="decode",
+         )

          # Capture trace
          trace_id = ttnn.begin_trace_capture(self.mesh_device, cq_id=0)

          # Run TT model
          tt_inp_emb = self.tt_model.tt_embd(tt_inp)
          tt_inp_emb = ttnn.interleaved_to_sharded(tt_inp_emb, self.model_config["WORD_EMBEDDING_OUTPUT_MEMCFG"])
-         tt_logits = self.tt_model(tt_inp_emb, rot_mat, start_pos, cache_idxs=cache_idxs_tt, mode="decode")
+         tt_logits = self.tt_model(
+             tt_inp_emb,
+             rot_mat,
+             start_pos,
+             cache_idxs=cache_idxs_tt,
+             page_table=tt_page_table,
+             kv_cache=kv_cache,
+             mode="decode",
+         )

          ttnn.end_trace_capture(self.mesh_device, trace_id, cq_id=0)
          logger.info("Done Capturing Decode Trace")

-         return trace_id, tt_inp, rot_mat, cache_idxs_tt, tt_logits
+         return trace_id, tt_inp, rot_mat, cache_idxs_tt, tt_logits, tt_page_table

      def delete_trace(self, trace_id):
          ttnn.release_trace(self.mesh_device, trace_id)

      def decode_forward_trace(
-         self, tokens: torch.Tensor, start_pos: int, trace_id, tt_inp, rot_mat, cache_idxs_tt, tt_logits
+         self,
+         tokens: torch.Tensor,
+         start_pos: int,
+         trace_id,
+         tt_inp,
+         rot_mat,
+         cache_idxs_tt,
+         tt_logits,
+         page_table=None,
+         tt_page_table=None,
      ):
          batch = tokens.shape[0]

@@ -155,10 +178,13 @@ def decode_forward_trace(
              start_pos,
              updated_rot_mat,
              updated_cache_idxs_tt,
-         ) = self.tt_model.prepare_inputs(tokens, start_pos, mode="decode")
+             updated_tt_page_table,
+         ) = self.tt_model.prepare_inputs(tokens, start_pos, mode="decode", page_table=page_table)
          ttnn.copy_host_to_device_tensor(updated_tt_inp, tt_inp)
          ttnn.copy_host_to_device_tensor(updated_rot_mat, rot_mat)
          ttnn.copy_host_to_device_tensor(updated_cache_idxs_tt, cache_idxs_tt)
+         if page_table is not None:
+             ttnn.copy_host_to_device_tensor(updated_tt_page_table, tt_page_table)

          # Run TT model
          ttnn.execute_trace(self.mesh_device, trace_id, cq_id=0, blocking=False)
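Taken together, these methods give a capture-once, replay-per-token decode loop: capture_trace compiles the model and records the decode graph (including the paged-attention path when a page table is supplied), returning handles to the trace's persistent device tensors, and decode_forward_trace refreshes those tensors with copy_host_to_device_tensor before each replay. A hypothetical driver loop, sketched from the signatures above rather than taken from this commit (it assumes model is the generation wrapper patched here, that earlier prefill populated the KV cache below start_pos, and that decode_forward_trace returns host logits as in the untraced path):

    trace_id, tt_inp, rot_mat, cache_idxs_tt, tt_logits, tt_page_table = model.capture_trace(
        tokens[:, start_pos : start_pos + 1], start_pos, page_table=page_table, kv_cache=kv_cache
    )
    for cur_pos in range(start_pos, start_pos + max_gen_len):  # bounds handling elided
        logits = model.decode_forward_trace(
            tokens[:, cur_pos : cur_pos + 1],
            cur_pos,
            trace_id,
            tt_inp,
            rot_mat,
            cache_idxs_tt,
            tt_logits,
            page_table=page_table,        # host-side table, re-staged each step
            tt_page_table=tt_page_table,  # persistent device tensor, updated in place
        )
        tokens[:, cur_pos + 1] = logits[:, -1].argmax(dim=-1)  # greedy pick, for illustration
    model.delete_trace(trace_id)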
@@ -173,29 +199,18 @@ def decode_forward(self, tokens: torch.Tensor, start_pos: int, page_table=None, kv_cache=None):

      def decode_forward(self, tokens: torch.Tensor, start_pos: int, page_table=None, kv_cache=None):
          batch = tokens.shape[0]
-         tt_inp, start_pos, rot_mat, cache_idxs_tt = self.tt_model.prepare_inputs(tokens, start_pos, mode="decode")
-         tt_inp = ttnn.to_device(tt_inp, self.mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG)
-         tt_inp_emb = self.tt_model.tt_embd(tt_inp)
-         tt_inp_emb = ttnn.interleaved_to_sharded(tt_inp_emb, self.model_config["WORD_EMBEDDING_OUTPUT_MEMCFG"])
-         rot_mat = ttnn.to_device(rot_mat, self.mesh_device, memory_config=self.model_config["ROT_MAT_MM_IN1_MEMCFG"])
-         cache_idxs_tt = ttnn.to_device(cache_idxs_tt, self.mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG)

-         if isinstance(page_table, torch.Tensor):
-             # Support vLLM tensor page_table input
-             page_table = ttnn.as_tensor(
-                 page_table,
-                 device=self.mesh_device,
-                 memory_config=ttnn.DRAM_MEMORY_CONFIG,
-                 dtype=ttnn.int32,
-                 layout=ttnn.ROW_MAJOR_LAYOUT,
-                 mesh_mapper=ReplicateTensorToMesh(self.mesh_device),
-             )
+         # Get inputs on device
+         tt_inp_emb, start_pos, rot_mat, cache_idxs_tt, tt_page_table = self.tt_model.prepare_device_inputs(
+             tokens, start_pos, mode="decode", page_table=page_table
+         )

          tt_logits = self.tt_model(
              tt_inp_emb,
              rot_mat,
              start_pos,
              cache_idxs=cache_idxs_tt,
-             page_table=page_table,
+             page_table=tt_page_table,
              kv_cache=kv_cache,
              mode="decode",
          )
@@ -218,34 +233,24 @@ def prefill_forward_single_user(
          assert batch == 1
          assert start_pos == 0, "start_pos must be 0 for prefill_forward_single_user"

-         tt_inp_emb, start_pos, rot_mat, _ = self.tt_model.prepare_inputs(
-             tokens, start_pos=start_pos, valid_seq_len=seq_len, mode="prefill"
+         tt_inp_emb, start_pos, rot_mat, _, tt_page_table = self.tt_model.prepare_device_inputs(
+             tokens, start_pos=start_pos, valid_seq_len=seq_len, mode="prefill", page_table=page_table
          )

-         if isinstance(page_table, torch.Tensor):
-             # Support vLLM tensor page_table input
-             page_table = ttnn.as_tensor(
-                 page_table,
-                 device=self.mesh_device,
-                 memory_config=ttnn.DRAM_MEMORY_CONFIG,
-                 dtype=ttnn.int32,
-                 layout=ttnn.ROW_MAJOR_LAYOUT,
-                 mesh_mapper=ReplicateTensorToMesh(self.mesh_device),
-             )

          tt_logits = self.tt_model(
              tt_inp_emb,
              rot_mat,
              start_pos,
              user_id=user_id,
              last_token_idx=last_token_idx,
-             page_table=page_table,
+             page_table=tt_page_table,
              kv_cache=kv_cache,
              mode="prefill",
          )

          del tt_inp_emb
          del rot_mat
+         del tt_page_table

          logits = self._process_logits(tt_logits)
          logits = logits.squeeze(1)
56 changes: 54 additions & 2 deletions models/demos/t3000/llama2_70b/tt/llama_model_optimized.py
@@ -168,7 +168,7 @@ def validate_input_shape(self, inp_ids):
              seq_len <= self.model_config["MAX_CONTEXT_LEN"]
          ), f"Sequence length {seq_len} exceeds MAX_CONTEXT_LEN {self.model_config['MAX_CONTEXT_LEN']}"

-     def prepare_inputs(self, inp_ids, start_pos, valid_seq_len=None, mode="decode"):
+     def prepare_inputs(self, inp_ids, start_pos, valid_seq_len=None, mode="decode", page_table=None):
          """
          Prepare inputs for decode mode. Assume that current token is at
          start_pos, and KV cache has valid data up to start_pos.
@@ -240,6 +240,17 @@ def prepare_inputs(self, inp_ids, start_pos, valid_seq_len=None, mode="decode"):

              cache_idxs_tt = None  # unused in prefill mode

+             if isinstance(page_table, torch.Tensor):
+                 # Support vLLM tensor page_table input
+                 page_table = ttnn.as_tensor(
+                     page_table,
+                     device=self.mesh_device,
+                     memory_config=ttnn.DRAM_MEMORY_CONFIG,
+                     dtype=ttnn.int32,
+                     layout=ttnn.ROW_MAJOR_LAYOUT,
+                     mesh_mapper=ReplicateTensorToMesh(self.mesh_device),
+                 )
+
          elif mode == "decode":
              assert seq_len == 1, "Decode mode only supports seq_len=1"
              xs = x
@@ -269,7 +280,48 @@ def prepare_inputs(self, inp_ids, start_pos, valid_seq_len=None, mode="decode"):
                  mesh_mapper=ReplicateTensorToMesh(self.mesh_device),
              )

-         return (xs, start_pos, rot_mats, cache_idxs_tt)
+             if isinstance(page_table, torch.Tensor):
+                 # Support vLLM tensor page_table input
+                 page_table = ttnn.as_tensor(
+                     page_table,
+                     dtype=ttnn.int32,
+                     layout=ttnn.ROW_MAJOR_LAYOUT,
+                     mesh_mapper=ReplicateTensorToMesh(self.mesh_device),
+                 )
+
+         return (xs, start_pos, rot_mats, cache_idxs_tt, page_table)
+
+     def prepare_device_inputs(
+         self,
+         tokens: torch.Tensor,
+         start_pos: int,
+         valid_seq_len=None,
+         mode="decode",
+         page_table=None,
+         return_tokens=False,  # if true, return tokens for decode mode
+     ):
+         tt_inp, start_pos, rot_mat, cache_idxs_tt, tt_page_table = self.prepare_inputs(
+             tokens, start_pos, valid_seq_len=valid_seq_len, mode=mode, page_table=page_table
+         )
+
+         if mode == "decode":
+             tt_inp = ttnn.to_device(tt_inp, self.mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG)
+             tt_inp_emb = self.tt_embd(tt_inp)
+             tt_inp_emb = ttnn.interleaved_to_sharded(tt_inp_emb, self.model_config["WORD_EMBEDDING_OUTPUT_MEMCFG"])
+             rot_mat = ttnn.to_device(
+                 rot_mat, self.mesh_device, memory_config=self.model_config["ROT_MAT_MM_IN1_MEMCFG"]
+             )
+             cache_idxs_tt = ttnn.to_device(cache_idxs_tt, self.mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG)
+             if tt_page_table is not None:
+                 tt_page_table = ttnn.to_device(tt_page_table, self.mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG)
+         else:
+             tt_inp_emb = tt_inp
+
+         return_out = []
+         if mode == "decode" and return_tokens:
+             return_out.append(tt_inp)
+         return_out.extend([tt_inp_emb, start_pos, rot_mat, cache_idxs_tt, tt_page_table])
+         return tuple(return_out)

      def __call__(
          self,
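One subtlety in the helper: prepare_inputs converts a torch page table differently per mode. The prefill branch places it directly on the mesh device, while the decode branch converts it host-side only, so that prepare_device_inputs can allocate the device copy once and decode_forward_trace can refresh that same buffer with ttnn.copy_host_to_device_tensor on every trace replay. A sketch of the two decode-mode call shapes (assuming a constructed tt_model; names follow the diff):

    # Standard call: five return values (in prefill mode, cache_idxs_tt comes back as None)
    tt_inp_emb, start_pos, rot_mat, cache_idxs_tt, tt_page_table = tt_model.prepare_device_inputs(
        tokens, start_pos, mode="decode", page_table=page_table
    )

    # Trace capture also needs the on-device token tensor itself, so that the trace's
    # input buffer can be rewritten between replays: return_tokens=True prepends it
    tt_inp, tt_inp_emb, start_pos, rot_mat, cache_idxs_tt, tt_page_table = tt_model.prepare_device_inputs(
        tokens, start_pos, mode="decode", page_table=page_table, return_tokens=True
    )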
