diff --git a/models/experimental/mamba/tt_opt/full_model.py b/models/experimental/mamba/tt_opt/full_model.py
index 4a100f86d37..d602f2ba546 100644
--- a/models/experimental/mamba/tt_opt/full_model.py
+++ b/models/experimental/mamba/tt_opt/full_model.py
@@ -11,6 +11,7 @@
 from models.experimental.mamba.tt_opt.residual_block import TtResidualBlock
 
+
 class TtTensorLoader:
     def __init__(self, state_dict, device, tt_cache_path: str = ""):
         self.state_dict = state_dict
@@ -22,7 +23,7 @@ def load_tt_tensor(
             name: str,
             tm_fn: Callable = lambda x: x,
             postfix: str = "",
-            device: ttnn.device = self.device,
+            device: ttnn.Device = self.device,
             tt_layout=ttnn.TILE_LAYOUT,
             tt_memory_config=ttnn.DRAM_MEMORY_CONFIG,
             tt_dtype=ttnn.bfloat16,
@@ -36,6 +37,12 @@ def load_tt_tensor(
             if torch_tensor is None:
                 torch_tensor = self.state_dict[tensor_name]
             torch_tensor = tm_fn(torch_tensor)
+
+            # Make all loaded tensors rank 4, since certain ttnn ops have
+            # performance issues when used with rank-1/2 tensors
+            while len(torch_tensor.size()) < 4:
+                torch_tensor = torch_tensor.unsqueeze(0)
+
             tt_tensor = ttnn.as_tensor(
                 torch_tensor,
                 device=device,
diff --git a/models/experimental/mamba/tt_opt/mamba_one_step_ssm.py b/models/experimental/mamba/tt_opt/mamba_one_step_ssm.py
index c992e4e683b..e216c2d1b13 100644
--- a/models/experimental/mamba/tt_opt/mamba_one_step_ssm.py
+++ b/models/experimental/mamba/tt_opt/mamba_one_step_ssm.py
@@ -6,6 +6,7 @@
 import torch.nn.functional as F
 
 import ttnn
+import tt_lib as ttl
 from typing import Callable
 
 from models.utility_functions import torch2tt_tensor
@@ -38,7 +39,9 @@ def __init__(self, args: ModelArgs, device, configs, load_fn: Callable):
         x_proj_weight_name = "mixer.x_proj.weight"
 
         # delta_t_proj_weights
-        self.delta_t_proj_weights = load_fn(x_proj_weight_name, lambda x: x[: self.args.dt_rank, :].transpose(-1, -2), postfix="delta_t")
+        self.delta_t_proj_weights = load_fn(
+            x_proj_weight_name, lambda x: x[: self.args.dt_rank, :].transpose(-1, -2), postfix="delta_t"
+        )
 
         # B_proj_weights
         def preprocess_B(x):
@@ -49,7 +52,8 @@ def preprocess_B(x):
 
         self.B_proj_weights = load_fn(
             x_proj_weight_name,
-            tm_fn=preprocess_B, postfix="B_proj",
+            tm_fn=preprocess_B,
+            postfix="B_proj",
         )
 
         # C_proj_weights
@@ -57,9 +61,8 @@ def preprocess_C(x):
             x = x[(self.args.dt_rank + self.args.d_state) :, :].transpose(-1, -2)
             x = F.pad(x, (0, 16), "constant", 0)
             return x
-        self.C_proj_weights = load_fn(
-            x_proj_weight_name, preprocess_C, postfix="C_proj"
-        )
+
+        self.C_proj_weights = load_fn(x_proj_weight_name, preprocess_C, postfix="C_proj")
 
         # dt_proj_weights
         dt_proj_weight_name = "mixer.dt_proj.weight"
@@ -71,12 +74,13 @@ def preprocess_C(x):
 
         # A weight
         A_weight_name = "mixer.A_log"
+
         def preprocess_A(x):
             x = -torch.exp(x.float())
             # padding with inf
             x = F.pad(x, (0, 16), "constant", float("-inf"))
-            x = x.reshape(1, self.hidden_size*32) # (1, 2en)
-            return x.repeat(self.num_users, 1) # b, 2en
+            x = x.reshape(1, self.hidden_size * 32)  # (1, 2en)
+            return x.repeat(self.num_users, 1)  # b, 2en
 
         self.A = load_fn(A_weight_name, tm_fn=preprocess_A, postfix=f"A_{self.args.batch_size}")
 
@@ -89,22 +93,38 @@ def preprocess_A(x):
         )
 
         # hidden state
-        prev_hidden_states = torch.zeros((1, 1, self.num_users, self.hidden_size*self.n))
+        prev_hidden_states = torch.zeros((1, 1, self.num_users, self.hidden_size * self.n))
         self.tt_hidden_state = load_fn(f"tt_hidden_state_{args.batch_size}", torch_tensor=prev_hidden_states)
 
+        self.compute_kernel_config = ttl.tensor.WormholeComputeKernelConfig(
+            math_fidelity=ttl.tensor.MathFidelity.LoFi,
+            math_approx_mode=False,
+            fp32_dest_acc_en=True,
+        )
+        self.core_grid_row = 4
+        self.core_grid_col = 8
+
     def forward(self, x):
         # delta
-        delta_t_proj_weights = ttnn.to_memory_config(self.delta_t_proj_weights, memory_config=ttnn.L1_MEMORY_CONFIG)
-        delta_t0 = ttnn.linear(x, delta_t_proj_weights, memory_config=ttnn.L1_MEMORY_CONFIG)
-        ttnn.deallocate(delta_t_proj_weights)
+        delta_t0 = ttnn.linear(
+            x,
+            self.delta_t_proj_weights,
+            memory_config=ttnn.L1_MEMORY_CONFIG,
+            compute_kernel_config=self.compute_kernel_config,
+            use_1d_systolic_array=True,
+            core_grid=ttnn.CoreGrid(y=self.core_grid_row, x=self.core_grid_col),
+        )
 
-        dt_proj_weights = ttnn.to_memory_config(self.dt_proj_weights, memory_config=self.configs["sharded_rank"])
         delta_t1 = ttnn.linear(
-            delta_t0, self.dt_proj_weights, bias=self.dt_proj_bias, memory_config=ttnn.L1_MEMORY_CONFIG
+            delta_t0,
+            self.dt_proj_weights,
+            bias=self.dt_proj_bias,
+            memory_config=ttnn.L1_MEMORY_CONFIG,
+            compute_kernel_config=self.compute_kernel_config,
+            use_1d_systolic_array=True,
+            core_grid=ttnn.CoreGrid(y=self.core_grid_row, x=self.core_grid_col),
         )
         ttnn.deallocate(delta_t0)
-        ttnn.deallocate(dt_proj_weights)
 
         delta_t2 = ttnn.softplus(delta_t1, parameter1=1.0, parameter2=20.0, memory_config=ttnn.L1_MEMORY_CONFIG)
         ttnn.deallocate(delta_t1)
@@ -113,56 +133,69 @@ def forward(self, x):
 
         delta_t3 = ttnn.repeat_interleave(delta_t2, self.n, dim=3)
         ttnn.deallocate(delta_t2)
 
-        delta_t4 = ttnn.to_memory_config(delta_t3, memory_config=self.configs["sharded_large"])
-        abar0 = ttnn.to_memory_config(self.A, memory_config=self.configs["sharded_large"])
-        abar1 = ttnn.mul(delta_t4, abar0, memory_config=self.configs["sharded_large"])
+        # shard delta and A
+        delta_t4 = ttnn.to_memory_config(delta_t3, memory_config=self.configs["sharded_dn"])
+        abar0 = ttnn.to_memory_config(self.A, memory_config=self.configs["sharded_dn"])
+
+        abar1 = ttnn.mul(delta_t4, abar0, memory_config=self.configs["sharded_dn"])
         ttnn.deallocate(abar0)
         ttnn.deallocate(delta_t4)
 
         abar2 = ttnn.to_memory_config(abar1, memory_config=ttnn.L1_MEMORY_CONFIG)
+        ttnn.deallocate(abar1)  # NOTE: freeing abar1 at this point previously caused a crash
        abar3 = ttnn.exp(abar2, memory_config=ttnn.L1_MEMORY_CONFIG)
-        abar4 = ttnn.to_memory_config(abar3, memory_config=self.configs["sharded_large"])
+        ttnn.deallocate(abar2)
+
+        abar4 = ttnn.to_memory_config(abar3, memory_config=self.configs["sharded_dn"])
+        ttnn.deallocate(abar3)  # NOTE: freeing abar3 at this point previously caused a crash
 
         # multiply abar and hidden_state
-        hidden_state0 = ttnn.to_memory_config(self.tt_hidden_state, memory_config=self.configs["sharded_large"])
-        amulh0 = ttnn.mul(abar4, hidden_state0, memory_config=self.configs["sharded_large"])
+        hidden_state0 = ttnn.to_memory_config(self.tt_hidden_state, memory_config=self.configs["sharded_dn"])
+        amulh0 = ttnn.mul(abar4, hidden_state0, memory_config=self.configs["sharded_dn"])
 
         # deallocate abar and hidden_state
-        # ttnn.deallocate(abar1) # THIS CAUSES A CRASH
-        ttnn.deallocate(abar2)
-        # ttnn.deallocate(abar3) # THIS CAUSES A CRASH
+        ttnn.deallocate(abar4)
         ttnn.deallocate(hidden_state0)
 
         # B
-        B_proj_weights = ttnn.to_memory_config(self.B_proj_weights, memory_config=ttnn.L1_MEMORY_CONFIG)
-        B0 = ttnn.linear(x, B_proj_weights, memory_config=ttnn.L1_MEMORY_CONFIG)
-        ttnn.deallocate(B_proj_weights)
+        B0 = ttnn.linear(
+            x,
+            self.B_proj_weights,
+            memory_config=ttnn.L1_MEMORY_CONFIG,
+            compute_kernel_config=self.compute_kernel_config,
+            use_1d_systolic_array=True,
+            core_grid=ttnn.CoreGrid(y=self.core_grid_row, x=self.core_grid_col),
+        )
+
         B1 = ttnn.repeat(B0, ttnn.Shape([1, 1, 1, self.hidden_size], [1, 1, 32, self.hidden_size]))
         ttnn.deallocate(B0)
-        B2 = ttnn.to_memory_config(B1, memory_config=self.configs["sharded_large"])
+
+        # shard B
+        B2 = ttnn.to_memory_config(B1, memory_config=self.configs["sharded_dn"])
         ttnn.deallocate(B1)
 
-        # bbar
-        delta_t4 = ttnn.to_memory_config(delta_t3, memory_config=self.configs["sharded_large"])
+        # shard delta
+        delta_t4 = ttnn.to_memory_config(delta_t3, memory_config=self.configs["sharded_dn"])
         delta_t3.deallocate()
 
-        bbar0 = ttnn.mul(delta_t4, B2, memory_config=self.configs["sharded_large"])
+        # bbar
+        bbar0 = ttnn.mul(delta_t4, B2, memory_config=self.configs["sharded_dn"])
         ttnn.deallocate(delta_t4)
         ttnn.deallocate(B2)
 
         # multiply bbar and x
         x0 = ttnn.repeat_interleave(x, self.n, dim=3)
-        x1 = ttnn.to_memory_config(x0, memory_config=self.configs["sharded_large"])
+        x1 = ttnn.to_memory_config(x0, memory_config=self.configs["sharded_dn"])
         ttnn.deallocate(x0)
-        bmulx0 = ttnn.mul(bbar0, x1, memory_config=self.configs["sharded_large"])
+        bmulx0 = ttnn.mul(bbar0, x1, memory_config=self.configs["sharded_dn"])
 
         # deallocate bbar
         ttnn.deallocate(bbar0)
         ttnn.deallocate(x1)
 
         # add amulh and bmulx
-        hidden_state1 = ttnn.add(amulh0, bmulx0, memory_config=self.configs["sharded_large"])
+        hidden_state1 = ttnn.add(amulh0, bmulx0, memory_config=self.configs["sharded_dn"])
         ttnn.deallocate(self.tt_hidden_state)
         self.tt_hidden_state = ttnn.to_memory_config(hidden_state1, memory_config=ttnn.DRAM_MEMORY_CONFIG)
@@ -171,9 +204,15 @@ def forward(self, x):
         ttnn.deallocate(bmulx0)
 
         # compute C
-        C_proj = ttnn.to_memory_config(self.C_proj_weights, memory_config=ttnn.L1_MEMORY_CONFIG)
-        C0 = ttnn.linear(x, C_proj, memory_config=ttnn.L1_MEMORY_CONFIG)  # b,n
-        ttnn.deallocate(C_proj)
+        C0 = ttnn.linear(
+            x,
+            self.C_proj_weights,
+            memory_config=ttnn.L1_MEMORY_CONFIG,
+            compute_kernel_config=self.compute_kernel_config,
+            use_1d_systolic_array=True,
+            core_grid=ttnn.CoreGrid(y=self.core_grid_row, x=self.core_grid_col),
+        )  # b,n
+        # C_proj weights are now used in place, so there is no temporary copy to deallocate
         C1 = ttnn.permute(C0, (0, 2, 3, 1))  # b,n,1
         ttnn.deallocate(C0)
@@ -188,13 +227,21 @@ def forward(self, x):
         C3 = ttnn.permute(C2, (0, 3, 1, 2))  # b, d
         ttnn.deallocate(C2)
 
+        # shard x and C
+        x = ttnn.to_memory_config(x, memory_config=self.configs["sharded_d"])
+        C4 = ttnn.to_memory_config(C3, memory_config=self.configs["sharded_d"])
+        ttnn.deallocate(C3)
+
+        # shard D
+        D = ttnn.to_memory_config(self.D, memory_config=self.configs["sharded_d"])
+
         # x * D
-        xD = ttnn.mul(x, self.D, memory_config=ttnn.L1_MEMORY_CONFIG)
+        xD = ttnn.mul(x, D, memory_config=self.configs["sharded_d"])
         ttnn.deallocate(x)
 
         # add xD and the SSM output
-        output = ttnn.add(xD, C3, memory_config=ttnn.L1_MEMORY_CONFIG)
+        output = ttnn.add(xD, C4, memory_config=self.configs["sharded_d"])
         ttnn.deallocate(xD)
-        ttnn.deallocate(C3)
+        ttnn.deallocate(C4)
 
         return output
diff --git a/models/experimental/mamba/tt_opt/model_config.py b/models/experimental/mamba/tt_opt/model_config.py
index f2cd36b8ae8..ef9ac9f44d0 100644
--- a/models/experimental/mamba/tt_opt/model_config.py
+++ b/models/experimental/mamba/tt_opt/model_config.py
@@ -14,7 +14,7 @@ def create_model_config(num_users, hidden_size):
 
     # num_users, hidden_size*2
    configs["sharded"] = ttnn.L1_MEMORY_CONFIG
-    '''
+    """
     ttnn.create_sharded_memory_config(
         shape=(1, 1, num_users, hidden_size*2 // (row * col)),
         core_grid=ttnn.CoreGrid(y=row, x=col),
@@ -22,10 +22,10 @@ def create_model_config(num_users, hidden_size):
         strategy=ttnn.ShardStrategy.WIDTH,
         orientation=orientation,
         use_height_and_width_as_shard_shape=True,
     )
-    '''
+    """
     configs["sharded_large"] = ttnn.L1_MEMORY_CONFIG
-    '''
+    """
     ttnn.create_sharded_memory_config(
         shape=(1, 1, num_users, hidden_size*2 * latent // (row * col)),
         core_grid=ttnn.CoreGrid(y=row, x=col),
@@ -33,9 +33,9 @@ def create_model_config(num_users, hidden_size):
         orientation=orientation,
         use_height_and_width_as_shard_shape=True,
     )
-    '''
+    """
     configs["sharded_rank"] = ttnn.L1_MEMORY_CONFIG
-    '''
+    """
     ttnn.create_sharded_memory_config(
         shape=(1, 1, hidden_size*2 // 32, hidden_size*2 // (row * col)),
         core_grid=ttnn.CoreGrid(y=row, x=col),
@@ -43,7 +43,17 @@ def create_model_config(num_users, hidden_size):
         strategy=ttnn.ShardStrategy.WIDTH,
         orientation=orientation,
         use_height_and_width_as_shard_shape=True,
     )
-    '''
+    """
+    configs["sharded_d"] = ttnn.create_sharded_memory_config(
+        shape=(1, 1, num_users, hidden_size * 2),
+        core_grid=ttnn.CoreGrid(y=row, x=col),
+        strategy=ttnn.ShardStrategy.WIDTH,
+    )
+    configs["sharded_dn"] = ttnn.create_sharded_memory_config(
+        shape=(1, 1, num_users, hidden_size * 2 * latent),
+        core_grid=ttnn.CoreGrid(y=row, x=col),
+        strategy=ttnn.ShardStrategy.WIDTH,
+    )
 
     return configs
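
Usage note: the sketch below shows how the pieces introduced in this diff are meant to fit together: rank-4 weights left in DRAM and fed straight into ttnn.linear with the LoFi compute-kernel config, and elementwise work kept width-sharded in L1. It is a minimal illustration, not code from the change itself; the device setup and the sizes (num_users = 32, d_inner = 5120, a 4x8 core grid) are assumptions chosen to mirror the values used above.

    import torch
    import ttnn
    import tt_lib as ttl

    num_users, d_inner = 32, 5120  # assumed sizes; users fill exactly one 32-row tile

    device = ttnn.open_device(device_id=0)

    # Mirrors configs["sharded_d"] from model_config.py (hidden_size * 2 == d_inner):
    # each of the 4 * 8 = 32 cores owns a (32, d_inner // 32) slice of the tensor.
    sharded_d = ttnn.create_sharded_memory_config(
        shape=(1, 1, num_users, d_inner),
        core_grid=ttnn.CoreGrid(y=4, x=8),
        strategy=ttnn.ShardStrategy.WIDTH,
    )

    # Same kernel config the SSM now passes to every ttnn.linear:
    # LoFi math fidelity with fp32 accumulation in the destination registers.
    compute_kernel_config = ttl.tensor.WormholeComputeKernelConfig(
        math_fidelity=ttl.tensor.MathFidelity.LoFi,
        math_approx_mode=False,
        fp32_dest_acc_en=True,
    )

    # An activation in L1 and a weight left in DRAM, both rank 4 as the
    # loader in full_model.py now guarantees.
    x = ttnn.from_torch(
        torch.randn(1, 1, num_users, d_inner),
        dtype=ttnn.bfloat16,
        layout=ttnn.TILE_LAYOUT,
        device=device,
        memory_config=ttnn.L1_MEMORY_CONFIG,
    )
    w = ttnn.from_torch(
        torch.randn(1, 1, d_inner, d_inner),
        dtype=ttnn.bfloat16,
        layout=ttnn.TILE_LAYOUT,
        device=device,
        memory_config=ttnn.DRAM_MEMORY_CONFIG,
    )

    # Matmul straight out of DRAM-resident weights, as forward() now does for
    # the projections (no per-step to_memory_config copy of weights into L1).
    y = ttnn.linear(
        x,
        w,
        memory_config=ttnn.L1_MEMORY_CONFIG,
        compute_kernel_config=compute_kernel_config,
        use_1d_systolic_array=True,
        core_grid=ttnn.CoreGrid(y=4, x=8),
    )

    # Elementwise ops stay width-sharded in L1, avoiding inter-core movement.
    y_sharded = ttnn.to_memory_config(y, memory_config=sharded_d)
    out = ttnn.mul(y_sharded, y_sharded, memory_config=sharded_d)
    out = ttnn.to_memory_config(out, memory_config=ttnn.L1_MEMORY_CONFIG)

    ttnn.close_device(device)

The trade-off this mirrors: for single-token decode, weights are cheap to stream from DRAM once per step, while activations are touched repeatedly and so are kept in L1 and spread across the full grid for the elementwise stages.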