From 361d11b5ff91e3d4c58e1569f166357da01db2a6 Mon Sep 17 00:00:00 2001
From: Salar Hosseini
Date: Thu, 23 May 2024 21:19:20 +0000
Subject: [PATCH] #5383: [Falcon7b] Change pad/unpad in mlp to ttlib to fix
 multi-chip decode perf regression in demo

Signed-off-by: Salar Hosseini
---
 models/demos/falcon7b/tt/falcon_mlp.py    | 31 +++++++++++++++--------
 models/demos/t3000/falcon7b/demo_t3000.py |  2 +-
 2 files changed, 21 insertions(+), 12 deletions(-)

diff --git a/models/demos/falcon7b/tt/falcon_mlp.py b/models/demos/falcon7b/tt/falcon_mlp.py
index 1d0b5b2eec0..2323c15d94c 100644
--- a/models/demos/falcon7b/tt/falcon_mlp.py
+++ b/models/demos/falcon7b/tt/falcon_mlp.py
@@ -229,6 +229,7 @@ def __init__(
         self.state_dict = state_dict
         self.devices = devices
+        self.num_devices = len(devices)
         self.hidden_size = hidden_size
         self.model_config = model_config
         self.padding_value = model_config["MLP_PADDING_VALUE"]
@@ -274,7 +275,7 @@ def __init__(
 
     def _load_mlp_padded_tensors(self):
         tt_paddings = []
-        for device_id in range(len(self.devices)):
+        for device_id in range(self.num_devices):
             tt_padding = torch.zeros((1, 1, 32, 64)).bfloat16().float()  # 4608 - 4544 = 64, batch=32
             tt_padding = ttnn.from_torch(
                 tt_padding,
@@ -288,17 +289,19 @@ def _load_mlp_padded_tensors(self):
         self.model_config["MLP_DECODE_PADDING_TENSORS"] = tt_paddings
 
     def forward(self, x: tt_lib.tensor.Tensor) -> tt_lib.tensor.Tensor:
+        batch_size = x[0].shape[-2]  # assume all devices have same shape
         hidden_states = []
-        for device_id in range(len(x)):
-            # pad inputs with padding tensor if not already padded
-            if (
-                self.model_config["PREFILL_OPTIMIZED_MODE"]
-                and x[device_id].shape[-1] < self.padding_value
-                and self.prefill_seq_len in [1024, 2048]
-            ):
-                x[device_id] = ttnn.concat(
+        # pad inputs with padding tensor if not already padded
+        if (
+            self.model_config["PREFILL_OPTIMIZED_MODE"]
+            and self.hidden_size < self.padding_value
+            and self.prefill_seq_len in [1024, 2048]
+        ):
+            for device_id in range(self.num_devices):
+                x[device_id] = tt_lib.tensor.concat(
                     [x[device_id], self.model_config["MLP_DECODE_PADDING_TENSORS"][device_id]], dim=3
                 )
+        for device_id in range(self.num_devices):
             hidden_states.append(
                 tt_lib.tensor.falcon_dense_h_to_4h_matmul(
                     x[device_id],
@@ -309,7 +312,7 @@ def forward(self, x: tt_lib.tensor.Tensor) -> tt_lib.tensor.Tensor:
                 )
             )
             x[device_id].deallocate()
-        for device_id in range(len(x)):
+        for device_id in range(self.num_devices):
             hidden_states[device_id] = tt_lib.tensor.falcon_dense_4h_to_h_matmul(
                 hidden_states[device_id],
                 self.dense_4h_to_h_weights[device_id],
@@ -319,7 +322,13 @@ def forward(self, x: tt_lib.tensor.Tensor) -> tt_lib.tensor.Tensor:
             )
         # remove padding from output
         if self.model_config["PREFILL_OPTIMIZED_MODE"] and self.prefill_seq_len in [1024, 2048]:
-            hidden_states = [hidden_states[i][:, :, :, : self.hidden_size] for i in range(len(self.devices))]
+            for i in range(self.num_devices):
+                hidden_states[i] = tt_lib.tensor.unpad(
+                    hidden_states[i],
+                    [0, 0, 0, 0],
+                    [0, 0, batch_size - 1, self.hidden_size - 1],
+                    output_mem_config=self.model_config["DENSE_4H_TO_H_MM_OUTPUT_MEMCFG"],
+                )
 
         # return TT Tensor
         return hidden_states
diff --git a/models/demos/t3000/falcon7b/demo_t3000.py b/models/demos/t3000/falcon7b/demo_t3000.py
index 5e17efadcec..7800b4c1921 100644
--- a/models/demos/t3000/falcon7b/demo_t3000.py
+++ b/models/demos/t3000/falcon7b/demo_t3000.py
@@ -10,7 +10,7 @@
 @pytest.mark.parametrize(
     "perf_mode, expected_perf_prefill_decode, greedy_sampling, expected_greedy_output_path",
     (
-        (True, [5780, 700], False, None),
+        (True, [6600, 1050], False, None),
         (True, None, False, None),
         (False, None, True, "models/demos/t3000/falcon7b/expected_greedy_output.json"),
         (False, None, True, None),
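
Note: as a reference for reviewers, below is a minimal torch-only sketch of
the pad/unpad shape arithmetic that this patch moves onto device via tt_lib
ops. The shapes come from the diff (4544 padded to 4608, batch 32); treating
the unpad end indices as inclusive is an assumption suggested by the
`batch_size - 1` / `self.hidden_size - 1` arguments above.

import torch

# Falcon-7B's hidden size 4544 is padded up to MLP_PADDING_VALUE = 4608
# (4608 - 4544 = 64) so the dense matmuls run on tile-aligned widths.
batch_size, hidden_size, padded_size = 32, 4544, 4608

x = torch.randn(1, 1, batch_size, hidden_size)
padding = torch.zeros(1, 1, batch_size, padded_size - hidden_size)

# Pad: concatenate along the last dim, mirroring
# tt_lib.tensor.concat([x, padding], dim=3) in the patch.
x_padded = torch.cat([x, padding], dim=-1)
assert x_padded.shape[-1] == padded_size

# ... the padded h_to_4h / 4h_to_h matmuls would run here ...

# Unpad: tt_lib.tensor.unpad(t, [0, 0, 0, 0],
#                            [0, 0, batch_size - 1, hidden_size - 1])
# with inclusive end indices is equivalent to this host-side slice:
x_unpadded = x_padded[:, :, :batch_size, :hidden_size]
assert x_unpadded.shape == x.shape

The net effect is that the output-side unpad now runs as a single device op
with an explicit output memory config instead of the previous ttnn-style
slice, which the commit message ties to the multi-chip decode perf
regression; the raised expected_perf_prefill_decode targets in demo_t3000.py
(5780 -> 6600 prefill, 700 -> 1050 decode) presumably reflect the recovered
throughput.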