Skip to content

Commit

Permalink
#5383: [Falcon7b] Change pad/unpad in mlp to ttlib to fix multi-chip decode perf regression in demo
Browse files Browse the repository at this point in the history

Signed-off-by: Salar Hosseini <[email protected]>
  • Loading branch information
skhorasganiTT committed May 23, 2024
1 parent 1a90204 commit 361d11b
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 12 deletions.
31 changes: 20 additions & 11 deletions models/demos/falcon7b/tt/falcon_mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,7 @@ def __init__(

self.state_dict = state_dict
self.devices = devices
self.num_devices = len(devices)
self.hidden_size = hidden_size
self.model_config = model_config
self.padding_value = model_config["MLP_PADDING_VALUE"]
Expand Down Expand Up @@ -274,7 +275,7 @@ def __init__(

def _load_mlp_padded_tensors(self):
tt_paddings = []
for device_id in range(len(self.devices)):
for device_id in range(self.num_devices):
tt_padding = torch.zeros((1, 1, 32, 64)).bfloat16().float() # 4608 - 4544 = 64, batch=32
tt_padding = ttnn.from_torch(
tt_padding,
Expand All @@ -288,17 +289,19 @@ def _load_mlp_padded_tensors(self):
self.model_config["MLP_DECODE_PADDING_TENSORS"] = tt_paddings

def forward(self, x: tt_lib.tensor.Tensor) -> tt_lib.tensor.Tensor:
batch_size = x[0].shape[-2] # assume all devices have same shape
hidden_states = []
for device_id in range(len(x)):
# pad inputs with padding tensor if not already padded
if (
self.model_config["PREFILL_OPTIMIZED_MODE"]
and x[device_id].shape[-1] < self.padding_value
and self.prefill_seq_len in [1024, 2048]
):
x[device_id] = ttnn.concat(
# pad inputs with padding tensor if not already padded
if (
self.model_config["PREFILL_OPTIMIZED_MODE"]
and self.hidden_size < self.padding_value
and self.prefill_seq_len in [1024, 2048]
):
for device_id in range(self.num_devices):
x[device_id] = tt_lib.tensor.concat(
[x[device_id], self.model_config["MLP_DECODE_PADDING_TENSORS"][device_id]], dim=3
)
for device_id in range(self.num_devices):
hidden_states.append(
tt_lib.tensor.falcon_dense_h_to_4h_matmul(
x[device_id],
Expand All @@ -309,7 +312,7 @@ def forward(self, x: tt_lib.tensor.Tensor) -> tt_lib.tensor.Tensor:
)
)
x[device_id].deallocate()
for device_id in range(len(x)):
for device_id in range(self.num_devices):
hidden_states[device_id] = tt_lib.tensor.falcon_dense_4h_to_h_matmul(
hidden_states[device_id],
self.dense_4h_to_h_weights[device_id],
Expand All @@ -319,7 +322,13 @@ def forward(self, x: tt_lib.tensor.Tensor) -> tt_lib.tensor.Tensor:
)
# remove padding from output
if self.model_config["PREFILL_OPTIMIZED_MODE"] and self.prefill_seq_len in [1024, 2048]:
hidden_states = [hidden_states[i][:, :, :, : self.hidden_size] for i in range(len(self.devices))]
for i in range(self.num_devices):
hidden_states[i] = tt_lib.tensor.unpad(
hidden_states[i],
[0, 0, 0, 0],
[0, 0, batch_size - 1, self.hidden_size - 1],
output_mem_config=self.model_config["DENSE_4H_TO_H_MM_OUTPUT_MEMCFG"],
)

# return TT Tensor
return hidden_states
2 changes: 1 addition & 1 deletion models/demos/t3000/falcon7b/demo_t3000.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
@pytest.mark.parametrize(
"perf_mode, expected_perf_prefill_decode, greedy_sampling, expected_greedy_output_path",
(
(True, [5780, 700], False, None),
(True, [6600, 1050], False, None),
(True, None, False, None),
(False, None, True, "models/demos/t3000/falcon7b/expected_greedy_output.json"),
(False, None, True, None),
Expand Down

0 comments on commit 361d11b

Please sign in to comment.