From 814e8cc47f69df3848bce9072627e8166173b211 Mon Sep 17 00:00:00 2001
From: Salar Hosseini
Date: Wed, 9 Oct 2024 17:04:19 +0000
Subject: [PATCH] #0: Add optional page table input to llama trace functions
 to enable tracing in vLLM

Signed-off-by: Salar Hosseini
---
 models/demos/t3000/llama2_70b/demo/demo.py    |  2 +-
 .../llama2_70b/tests/test_llama_model.py      | 10 +-
 .../t3000/llama2_70b/tests/test_llama_perf.py |  8 +-
 .../tests/test_llama_perf_decode.py           |  8 +-
 .../tests/test_llama_stress_test.py           |  9 +-
 .../t3000/llama2_70b/tt/llama_generation.py   | 93 ++++++++++---------
 .../llama2_70b/tt/llama_model_optimized.py    | 56 ++++++++++-
 7 files changed, 115 insertions(+), 71 deletions(-)

diff --git a/models/demos/t3000/llama2_70b/demo/demo.py b/models/demos/t3000/llama2_70b/demo/demo.py
index 003127fc97f..f90bf59990e 100644
--- a/models/demos/t3000/llama2_70b/demo/demo.py
+++ b/models/demos/t3000/llama2_70b/demo/demo.py
@@ -261,7 +261,7 @@ def run_decode(
     # capture trace
     if trace_mode:
         logger.info("Capturing trace")
-        trace_id, tt_inp_emb, rot_mat, cache_idxs_tt, tt_logits = model.capture_trace(
+        trace_id, tt_inp_emb, rot_mat, cache_idxs_tt, tt_logits, _ = model.capture_trace(
             tokens[:, prev_pos:min_prompt_len], prev_pos
         )
 
diff --git a/models/demos/t3000/llama2_70b/tests/test_llama_model.py b/models/demos/t3000/llama2_70b/tests/test_llama_model.py
index b53d108a6ca..5ce8f220bbe 100644
--- a/models/demos/t3000/llama2_70b/tests/test_llama_model.py
+++ b/models/demos/t3000/llama2_70b/tests/test_llama_model.py
@@ -159,15 +159,7 @@ def run_test_LlamaModel_inference(
     if device_perf:
         signpost(DEVICE_PERF_START_SIGNPOST)  # start for device perf measurement
     # TT hardware execution -------------------------------------------------------------
-    tt_inp_emb, start_pos, rot_mat, cache_idxs = tt_model.prepare_inputs(tt_inp_ids, start_pos, mode=mode)
-
-    # Send to device
-    if mode == "decode":
-        tt_inp_emb = ttnn.to_device(tt_inp_emb, t3k_mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG)
-        tt_inp_emb = tt_model.tt_embd(tt_inp_emb)
-        tt_inp_emb = ttnn.interleaved_to_sharded(tt_inp_emb, model_config["WORD_EMBEDDING_OUTPUT_MEMCFG"])
-        rot_mat = ttnn.to_device(rot_mat, t3k_mesh_device, memory_config=model_config["ROT_MAT_MM_IN1_MEMCFG"])
-        cache_idxs = ttnn.to_device(cache_idxs, t3k_mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG)
+    tt_inp_emb, start_pos, rot_mat, cache_idxs, _ = tt_model.prepare_device_inputs(tt_inp_ids, start_pos, mode=mode)
 
     tt_out = tt_model(
         tt_inp_emb,
diff --git a/models/demos/t3000/llama2_70b/tests/test_llama_perf.py b/models/demos/t3000/llama2_70b/tests/test_llama_perf.py
index de8fab5b8c2..29ca2e75ff8 100644
--- a/models/demos/t3000/llama2_70b/tests/test_llama_perf.py
+++ b/models/demos/t3000/llama2_70b/tests/test_llama_perf.py
@@ -154,7 +154,9 @@ def run_test_LlamaModel_end_to_end(
 
         if user_id == 0 or user_id == 25:
             profiler.start(f"processing_of_prefill_input_{user_id}")
-        tt_inp_emb, start_pos, rot_mat = tt_model.prepare_inputs(prefill_ids[user_id : user_id + 1], start_pos=0)
+        tt_inp_emb, start_pos, rot_mat, _, _ = tt_model.prepare_device_inputs(
+            prefill_ids[user_id : user_id + 1], start_pos=0, mode="prefill"
+        )
         if user_id == 0 or user_id == 25:
             profiler.end(f"processing_of_prefill_input_{user_id}")
             profiler.start(f"model_run_for_prefill_{user_id}")
@@ -202,7 +204,9 @@ def run_test_LlamaModel_end_to_end(
 
         if cur_pos == 0 or cur_pos == 35:  # Skip the first few iterations to warm up
             profiler.start(f"processing_of_decode_input_{cur_pos}")
-            tt_inp_emb, start_pos, rot_mat, cache_idxs = tt_model.prepare_inputs(decode_ids, start_pos)
+            tt_inp_emb, start_pos, rot_mat, cache_idxs, _ = tt_model.prepare_device_inputs(
+                decode_ids, start_pos, mode="decode"
+            )
 
             tt_inp_emb = ttnn.to_device(tt_inp_emb, mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG)
             tt_inp_emb = tt_model.tt_embd(tt_inp_emb)
diff --git a/models/demos/t3000/llama2_70b/tests/test_llama_perf_decode.py b/models/demos/t3000/llama2_70b/tests/test_llama_perf_decode.py
index fbccd4176c3..dfb3cb8e582 100644
--- a/models/demos/t3000/llama2_70b/tests/test_llama_perf_decode.py
+++ b/models/demos/t3000/llama2_70b/tests/test_llama_perf_decode.py
@@ -127,13 +127,7 @@ def run_test_LlamaModel_end_to_end(
 
     ##### Prepare Inputs #####
     prev_pos = total_len - 1
-    tt_inp_emb, prev_pos, rot_mat, cache_idxs = tt_model.prepare_inputs(tokens, prev_pos)
-    tt_inp_emb = ttnn.to_device(tt_inp_emb, mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG)
-    tt_inp_emb = tt_model.tt_embd(tt_inp_emb)
-    tt_inp_emb = ttnn.interleaved_to_sharded(tt_inp_emb, tt_model.model_config["WORD_EMBEDDING_OUTPUT_MEMCFG"])
-
-    rot_mat = ttnn.to_device(rot_mat, mesh_device, memory_config=tt_model.model_config["ROT_MAT_MM_IN1_MEMCFG"])
-    cache_idxs = ttnn.to_device(cache_idxs, mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG)
+    tt_inp_emb, prev_pos, rot_mat, cache_idxs, _ = tt_model.prepare_device_inputs(tokens, prev_pos)
 
     ##### Compile Model #####
     logger.info("Compiling model")
diff --git a/models/demos/t3000/llama2_70b/tests/test_llama_stress_test.py b/models/demos/t3000/llama2_70b/tests/test_llama_stress_test.py
index a5b9edc7f81..bbfd1e68bd6 100644
--- a/models/demos/t3000/llama2_70b/tests/test_llama_stress_test.py
+++ b/models/demos/t3000/llama2_70b/tests/test_llama_stress_test.py
@@ -98,12 +98,9 @@ def run_test_LlamaModel_stress_test(
     start_pos = 0
     prev_pos = start_pos
     for cur_pos in tqdm(range(start_pos + 1, total_len), desc="Decode to 2k Progress", leave=False, colour="green"):
-        tt_inp_emb, prev_pos, rot_mat, cache_idxs = tt_model.prepare_inputs(tokens[:, prev_pos:cur_pos], prev_pos)
-        tt_inp_emb = ttnn.to_device(tt_inp_emb, mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG)
-        tt_inp_emb = tt_model.tt_embd(tt_inp_emb)
-        tt_inp_emb = ttnn.interleaved_to_sharded(tt_inp_emb, model_config["WORD_EMBEDDING_OUTPUT_MEMCFG"])
-        rot_mat = ttnn.to_device(rot_mat, mesh_device, memory_config=model_config["ROT_MAT_MM_IN1_MEMCFG"])
-        cache_idxs = ttnn.to_device(cache_idxs, mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG)
+        tt_inp_emb, prev_pos, rot_mat, cache_idxs, _ = tt_model.prepare_device_inputs(
+            tokens[:, prev_pos:cur_pos], prev_pos
+        )
 
         tt_logits = tt_model(tt_inp_emb, rot_mat, prev_pos, cache_idxs=cache_idxs)
 
diff --git a/models/demos/t3000/llama2_70b/tt/llama_generation.py b/models/demos/t3000/llama2_70b/tt/llama_generation.py
index 926cbc23503..4e7fee77697 100644
--- a/models/demos/t3000/llama2_70b/tt/llama_generation.py
+++ b/models/demos/t3000/llama2_70b/tt/llama_generation.py
@@ -117,16 +117,22 @@ def forward(self, tokens: torch.Tensor, start_pos: int, page_table=None, kv_cach
             tokens, start_pos, page_table=page_table, kv_cache=kv_cache, prompt_lens=prompt_lens
         )
 
-    def capture_trace(self, tokens: torch.Tensor, start_pos: int):
-        tt_inp, start_pos, rot_mat, cache_idxs_tt = self.tt_model.prepare_inputs(tokens, start_pos, mode="decode")
+    def capture_trace(self, tokens: torch.Tensor, start_pos: int, page_table=None, kv_cache=None):
+        # Get inputs on device
+        tt_inp, tt_inp_emb, start_pos, rot_mat, cache_idxs_tt, tt_page_table = self.tt_model.prepare_device_inputs(
+            tokens, start_pos, mode="decode", page_table=page_table, return_tokens=True
+        )
 
         # Compile model
-        tt_inp = ttnn.to_device(tt_inp, self.mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG)
-        tt_inp_emb = self.tt_model.tt_embd(tt_inp)
-        tt_inp_emb = ttnn.interleaved_to_sharded(tt_inp_emb, self.model_config["WORD_EMBEDDING_OUTPUT_MEMCFG"])
-        rot_mat = ttnn.to_device(rot_mat, self.mesh_device, memory_config=self.model_config["ROT_MAT_MM_IN1_MEMCFG"])
-        cache_idxs_tt = ttnn.to_device(cache_idxs_tt, self.mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG)
-        tt_logits = self.tt_model(tt_inp_emb, rot_mat, start_pos, cache_idxs=cache_idxs_tt, mode="decode")
+        tt_logits = self.tt_model(
+            tt_inp_emb,
+            rot_mat,
+            start_pos,
+            cache_idxs=cache_idxs_tt,
+            page_table=tt_page_table,
+            kv_cache=kv_cache,
+            mode="decode",
+        )
 
         # Capture trace
         trace_id = ttnn.begin_trace_capture(self.mesh_device, cq_id=0)
@@ -134,18 +140,35 @@ def capture_trace(self, tokens: torch.Tensor, start_pos: int):
         # Run TT model
         tt_inp_emb = self.tt_model.tt_embd(tt_inp)
         tt_inp_emb = ttnn.interleaved_to_sharded(tt_inp_emb, self.model_config["WORD_EMBEDDING_OUTPUT_MEMCFG"])
-        tt_logits = self.tt_model(tt_inp_emb, rot_mat, start_pos, cache_idxs=cache_idxs_tt, mode="decode")
+        tt_logits = self.tt_model(
+            tt_inp_emb,
+            rot_mat,
+            start_pos,
+            cache_idxs=cache_idxs_tt,
+            page_table=tt_page_table,
+            kv_cache=kv_cache,
+            mode="decode",
+        )
 
         ttnn.end_trace_capture(self.mesh_device, trace_id, cq_id=0)
         logger.info("Done Capturing Decode Trace")
 
-        return trace_id, tt_inp, rot_mat, cache_idxs_tt, tt_logits
+        return trace_id, tt_inp, rot_mat, cache_idxs_tt, tt_logits, tt_page_table
 
     def delete_trace(self, trace_id):
         ttnn.release_trace(self.mesh_device, trace_id)
 
     def decode_forward_trace(
-        self, tokens: torch.Tensor, start_pos: int, trace_id, tt_inp, rot_mat, cache_idxs_tt, tt_logits
+        self,
+        tokens: torch.Tensor,
+        start_pos: int,
+        trace_id,
+        tt_inp,
+        rot_mat,
+        cache_idxs_tt,
+        tt_logits,
+        page_table=None,
+        tt_page_table=None,
     ):
         batch = tokens.shape[0]
 
@@ -155,10 +178,13 @@ def decode_forward_trace(
             start_pos,
             updated_rot_mat,
             updated_cache_idxs_tt,
-        ) = self.tt_model.prepare_inputs(tokens, start_pos, mode="decode")
+            updated_tt_page_table,
+        ) = self.tt_model.prepare_inputs(tokens, start_pos, mode="decode", page_table=page_table)
         ttnn.copy_host_to_device_tensor(updated_tt_inp, tt_inp)
         ttnn.copy_host_to_device_tensor(updated_rot_mat, rot_mat)
         ttnn.copy_host_to_device_tensor(updated_cache_idxs_tt, cache_idxs_tt)
+        if page_table is not None:
+            ttnn.copy_host_to_device_tensor(updated_tt_page_table, tt_page_table)
 
         # Run TT model
         ttnn.execute_trace(self.mesh_device, trace_id, cq_id=0, blocking=False)
@@ -173,29 +199,18 @@ def decode_forward(self, tokens: torch.Tensor, start_pos: int, page_table=None, kv_cache=None):
         batch = tokens.shape[0]
-        tt_inp, start_pos, rot_mat, cache_idxs_tt = self.tt_model.prepare_inputs(tokens, start_pos, mode="decode")
-        tt_inp = ttnn.to_device(tt_inp, self.mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG)
-        tt_inp_emb = self.tt_model.tt_embd(tt_inp)
-        tt_inp_emb = ttnn.interleaved_to_sharded(tt_inp_emb, self.model_config["WORD_EMBEDDING_OUTPUT_MEMCFG"])
-        rot_mat = ttnn.to_device(rot_mat, self.mesh_device, memory_config=self.model_config["ROT_MAT_MM_IN1_MEMCFG"])
-        cache_idxs_tt = ttnn.to_device(cache_idxs_tt, self.mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG)
-
-        if isinstance(page_table, torch.Tensor):
-            # Support vLLM tensor page_table input
-            page_table = ttnn.as_tensor(
-                page_table,
-                device=self.mesh_device,
-                memory_config=ttnn.DRAM_MEMORY_CONFIG,
-                dtype=ttnn.int32,
-                layout=ttnn.ROW_MAJOR_LAYOUT,
-                mesh_mapper=ReplicateTensorToMesh(self.mesh_device),
-            )
+
+        # Get inputs on device
+        tt_inp_emb, start_pos, rot_mat, cache_idxs_tt, tt_page_table = self.tt_model.prepare_device_inputs(
+            tokens, start_pos, mode="decode", page_table=page_table
+        )
+
         tt_logits = self.tt_model(
             tt_inp_emb,
             rot_mat,
             start_pos,
             cache_idxs=cache_idxs_tt,
-            page_table=page_table,
+            page_table=tt_page_table,
             kv_cache=kv_cache,
             mode="decode",
         )
@@ -218,34 +233,24 @@ def prefill_forward_single_user(
         assert batch == 1
         assert start_pos == 0, "start_pos must be 0 for prefill_forward_single_user"
 
-        tt_inp_emb, start_pos, rot_mat, _ = self.tt_model.prepare_inputs(
-            tokens, start_pos=start_pos, valid_seq_len=seq_len, mode="prefill"
+        tt_inp_emb, start_pos, rot_mat, _, tt_page_table = self.tt_model.prepare_device_inputs(
+            tokens, start_pos=start_pos, valid_seq_len=seq_len, mode="prefill", page_table=page_table
         )
 
-        if isinstance(page_table, torch.Tensor):
-            # Support vLLM tensor page_table input
-            page_table = ttnn.as_tensor(
-                page_table,
-                device=self.mesh_device,
-                memory_config=ttnn.DRAM_MEMORY_CONFIG,
-                dtype=ttnn.int32,
-                layout=ttnn.ROW_MAJOR_LAYOUT,
-                mesh_mapper=ReplicateTensorToMesh(self.mesh_device),
-            )
-
         tt_logits = self.tt_model(
             tt_inp_emb,
             rot_mat,
             start_pos,
             user_id=user_id,
             last_token_idx=last_token_idx,
-            page_table=page_table,
+            page_table=tt_page_table,
             kv_cache=kv_cache,
             mode="prefill",
        )
 
         del tt_inp_emb
         del rot_mat
+        del tt_page_table
 
         logits = self._process_logits(tt_logits)
         logits = logits.squeeze(1)
diff --git a/models/demos/t3000/llama2_70b/tt/llama_model_optimized.py b/models/demos/t3000/llama2_70b/tt/llama_model_optimized.py
index 570d9db330d..cc6fcab7920 100644
--- a/models/demos/t3000/llama2_70b/tt/llama_model_optimized.py
+++ b/models/demos/t3000/llama2_70b/tt/llama_model_optimized.py
@@ -168,7 +168,7 @@ def validate_input_shape(self, inp_ids):
             seq_len <= self.model_config["MAX_CONTEXT_LEN"]
         ), f"Sequence length {seq_len} exceeds MAX_CONTEXT_LEN {self.model_config['MAX_CONTEXT_LEN']}"
 
-    def prepare_inputs(self, inp_ids, start_pos, valid_seq_len=None, mode="decode"):
+    def prepare_inputs(self, inp_ids, start_pos, valid_seq_len=None, mode="decode", page_table=None):
         """
         Prepare inputs for decode mode.
         Assume that current token is at start_pos, and KV cache has valid data up to start_pos.
@@ -240,6 +240,17 @@ def prepare_inputs(self, inp_ids, start_pos, valid_seq_len=None, mode="decode"): cache_idxs_tt = None # unused in prefill mode + if isinstance(page_table, torch.Tensor): + # Support vLLM tensor page_table input + page_table = ttnn.as_tensor( + page_table, + device=self.mesh_device, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + dtype=ttnn.int32, + layout=ttnn.ROW_MAJOR_LAYOUT, + mesh_mapper=ReplicateTensorToMesh(self.mesh_device), + ) + elif mode == "decode": assert seq_len == 1, "Decode mode only supports seq_len=1" xs = x @@ -269,7 +280,48 @@ def prepare_inputs(self, inp_ids, start_pos, valid_seq_len=None, mode="decode"): mesh_mapper=ReplicateTensorToMesh(self.mesh_device), ) - return (xs, start_pos, rot_mats, cache_idxs_tt) + if isinstance(page_table, torch.Tensor): + # Support vLLM tensor page_table input + page_table = ttnn.as_tensor( + page_table, + dtype=ttnn.int32, + layout=ttnn.ROW_MAJOR_LAYOUT, + mesh_mapper=ReplicateTensorToMesh(self.mesh_device), + ) + + return (xs, start_pos, rot_mats, cache_idxs_tt, page_table) + + def prepare_device_inputs( + self, + tokens: torch.Tensor, + start_pos: int, + valid_seq_len=None, + mode="decode", + page_table=None, + return_tokens=False, # if true, return tokens for decode mode + ): + tt_inp, start_pos, rot_mat, cache_idxs_tt, tt_page_table = self.prepare_inputs( + tokens, start_pos, valid_seq_len=valid_seq_len, mode=mode, page_table=page_table + ) + + if mode == "decode": + tt_inp = ttnn.to_device(tt_inp, self.mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG) + tt_inp_emb = self.tt_embd(tt_inp) + tt_inp_emb = ttnn.interleaved_to_sharded(tt_inp_emb, self.model_config["WORD_EMBEDDING_OUTPUT_MEMCFG"]) + rot_mat = ttnn.to_device( + rot_mat, self.mesh_device, memory_config=self.model_config["ROT_MAT_MM_IN1_MEMCFG"] + ) + cache_idxs_tt = ttnn.to_device(cache_idxs_tt, self.mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG) + if tt_page_table is not None: + tt_page_table = ttnn.to_device(tt_page_table, self.mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG) + else: + tt_inp_emb = tt_inp + + return_out = [] + if mode == "decode" and return_tokens: + return_out.append(tt_inp) + return_out.extend([tt_inp_emb, start_pos, rot_mat, cache_idxs_tt, tt_page_table]) + return tuple(return_out) def __call__( self,
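
Usage sketch (editor's note, not part of the commit): the snippet below shows how
a vLLM-style caller might drive the page-table-aware trace path added above. It
is a minimal sketch written only against the signatures introduced in this patch:
capture_trace, decode_forward_trace, and delete_trace come from
llama_generation.py, while run_traced_decode, num_steps, the kv_cache object, and
the omitted sampling step are illustrative placeholders, not APIs from this
repository.

    import torch

    def run_traced_decode(model, tokens: torch.Tensor, start_pos: int,
                          page_table: torch.Tensor, kv_cache, num_steps: int):
        # Capture once: prepare_device_inputs moves tokens, rotary matrices,
        # cache indices, and the replicated int32 page table onto the mesh
        # device; the model is compiled and one decode step is recorded.
        trace_id, tt_inp, rot_mat, cache_idxs_tt, tt_logits, tt_page_table = model.capture_trace(
            tokens, start_pos, page_table=page_table, kv_cache=kv_cache
        )
        try:
            logits = None
            for _ in range(num_steps):
                # Steady state: decode_forward_trace copies fresh host inputs
                # (including the updated page table, when one is given) into the
                # captured device tensors, then replays the trace instead of
                # re-dispatching every op.
                logits = model.decode_forward_trace(
                    tokens,
                    start_pos,
                    trace_id,
                    tt_inp,
                    rot_mat,
                    cache_idxs_tt,
                    tt_logits,
                    page_table=page_table,
                    tt_page_table=tt_page_table,
                )  # returns host logits via the pre-existing tail of the function (not shown in the hunks above)
                # A real caller would sample the next tokens from `logits` here
                # and refresh `page_table` as the scheduler allocates new blocks.
                start_pos += 1
            return logits
        finally:
            model.delete_trace(trace_id)

Note that the replayed trace reads fixed device buffers, so everything that
changes between steps (tokens, position, block tables) must flow through the
copy_host_to_device_tensor updates in decode_forward_trace; the per-step
page_table therefore has to keep the shape captured at trace time.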