
Commit

#7012: Add sharding support to SSM block
kpaigwar authored and esmalTT committed Apr 2, 2024
1 parent 6c0a0d1 commit d61fc22
Showing 3 changed files with 111 additions and 46 deletions.
9 changes: 8 additions & 1 deletion models/experimental/mamba/tt_opt/full_model.py
@@ -11,6 +11,7 @@

from models.experimental.mamba.tt_opt.residual_block import TtResidualBlock


class TtTensorLoader:
def __init__(self, state_dict, device, tt_cache_path: str = ""):
self.state_dict = state_dict
@@ -22,7 +23,7 @@ def load_tt_tensor(
name: str,
tm_fn: Callable = lambda x: x,
postfix: str = "",
device: ttnn.device = self.device,
device: ttnn.Device = self.device,
tt_layout=ttnn.TILE_LAYOUT,
tt_memory_config=ttnn.DRAM_MEMORY_CONFIG,
tt_dtype=ttnn.bfloat16,
@@ -36,6 +37,12 @@
if torch_tensor is None:
torch_tensor = self.state_dict[tensor_name]
torch_tensor = tm_fn(torch_tensor)

# Make all loaded tensors rank 4 because there are performance issues with certain
# ops when used with rank 1/2 tensors in ttnn
while len(torch_tensor.size()) < 4:
torch_tensor = torch_tensor.unsqueeze(0)

tt_tensor = ttnn.as_tensor(
torch_tensor,
device=device,
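For illustration, a minimal sketch in plain PyTorch of the rank-4 padding that load_tt_tensor now applies before handing tensors to ttnn (the helper name is hypothetical):

import torch

def pad_to_rank_4(torch_tensor: torch.Tensor) -> torch.Tensor:
    # Prepend singleton dimensions until the tensor is rank 4, mirroring the
    # unsqueeze loop added to load_tt_tensor above.
    while len(torch_tensor.size()) < 4:
        torch_tensor = torch_tensor.unsqueeze(0)
    return torch_tensor

# Example: a rank-1 tensor of length 16 becomes shape (1, 1, 1, 16).
assert pad_to_rank_4(torch.zeros(16)).shape == (1, 1, 1, 16)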
126 changes: 87 additions & 39 deletions models/experimental/mamba/tt_opt/mamba_one_step_ssm.py
@@ -6,6 +6,7 @@
import torch.nn.functional as F

import ttnn
import tt_lib as ttl
from typing import Callable

from models.utility_functions import torch2tt_tensor
@@ -38,7 +39,9 @@ def __init__(self, args: ModelArgs, device, configs, load_fn: Callable):
x_proj_weight_name = "mixer.x_proj.weight"

# delta_t_proj_weights
self.delta_t_proj_weights = load_fn(x_proj_weight_name, lambda x: x[: self.args.dt_rank, :].transpose(-1, -2), postfix="delta_t")
self.delta_t_proj_weights = load_fn(
x_proj_weight_name, lambda x: x[: self.args.dt_rank, :].transpose(-1, -2), postfix="delta_t"
)

# B_proj_weights
def preprocess_B(x):
@@ -49,17 +52,17 @@ def preprocess_B(x):

self.B_proj_weights = load_fn(
x_proj_weight_name,
tm_fn=preprocess_B, postfix="B_proj",
tm_fn=preprocess_B,
postfix="B_proj",
)

# C_proj_weights
def preprocess_C(x):
x = x[(self.args.dt_rank + self.args.d_state) :, :].transpose(-1, -2)
x = F.pad(x, (0, 16), "constant", 0)
return x
self.C_proj_weights = load_fn(
x_proj_weight_name, preprocess_C, postfix="C_proj"
)

self.C_proj_weights = load_fn(x_proj_weight_name, preprocess_C, postfix="C_proj")

# dt_proj_weights
dt_proj_weight_name = "mixer.dt_proj.weight"
@@ -71,12 +74,13 @@ def preprocess_C(x):

# A weight
A_weight_name = "mixer.A_log"

def preprocess_A(x):
x = -torch.exp(x.float())
# padding with -inf
x = F.pad(x, (0, 16), "constant", float("-inf"))
x = x.reshape(1, self.hidden_size*32) # (1, 2en)
return x.repeat(self.num_users, 1) # b, 2en
x = x.reshape(1, self.hidden_size * 32) # (1, 2en)
return x.repeat(self.num_users, 1) # b, 2en

self.A = load_fn(A_weight_name, tm_fn=preprocess_A, postfix=f"A_{self.args.batch_size}")

@@ -89,22 +93,38 @@ def preprocess_A(x):
)

# hidden state
prev_hidden_states = torch.zeros((1, 1, self.num_users, self.hidden_size*self.n))
prev_hidden_states = torch.zeros((1, 1, self.num_users, self.hidden_size * self.n))
self.tt_hidden_state = load_fn(f"tt_hidden_state_{args.batch_size}", torch_tensor=prev_hidden_states)

self.compute_kernel_config = ttl.tensor.WormholeComputeKernelConfig(
math_fidelity=ttl.tensor.MathFidelity.LoFi,
math_approx_mode=False,
fp32_dest_acc_en=True,
)
self.core_grid_row = 4
self.core_grid_col = 8

def forward(self, x):
# delta
delta_t_proj_weights = ttnn.to_memory_config(self.delta_t_proj_weights, memory_config=ttnn.L1_MEMORY_CONFIG)
delta_t0 = ttnn.linear(x, delta_t_proj_weights, memory_config=ttnn.L1_MEMORY_CONFIG)
ttnn.deallocate(delta_t_proj_weights)
delta_t0 = ttnn.linear(
x,
self.delta_t_proj_weights,
memory_config=ttnn.L1_MEMORY_CONFIG,
compute_kernel_config=self.compute_kernel_config,
use_1d_systolic_array=True,
core_grid=ttnn.CoreGrid(y=self.core_grid_row, x=self.core_grid_col),
)

dt_proj_weights = ttnn.to_memory_config(self.dt_proj_weights, memory_config=self.configs["sharded_rank"])
delta_t1 = ttnn.linear(
delta_t0, self.dt_proj_weights, bias=self.dt_proj_bias, memory_config=ttnn.L1_MEMORY_CONFIG
delta_t0,
self.dt_proj_weights,
bias=self.dt_proj_bias,
memory_config=ttnn.L1_MEMORY_CONFIG,
compute_kernel_config=self.compute_kernel_config,
use_1d_systolic_array=True,
core_grid=ttnn.CoreGrid(y=self.core_grid_row, x=self.core_grid_col),
)
ttnn.deallocate(delta_t0)
ttnn.deallocate(dt_proj_weights)

delta_t2 = ttnn.softplus(delta_t1, parameter1=1.0, parameter2=20.0, memory_config=ttnn.L1_MEMORY_CONFIG)
ttnn.deallocate(delta_t1)
@@ -113,56 +133,69 @@ def forward(self, x):
delta_t3 = ttnn.repeat_interleave(delta_t2, self.n, dim=3)
ttnn.deallocate(delta_t2)

delta_t4 = ttnn.to_memory_config(delta_t3, memory_config=self.configs["sharded_large"])
abar0 = ttnn.to_memory_config(self.A, memory_config=self.configs["sharded_large"])
abar1 = ttnn.mul(delta_t4, abar0, memory_config=self.configs["sharded_large"])
# shard delta and A
delta_t4 = ttnn.to_memory_config(delta_t3, memory_config=self.configs["sharded_dn"])
abar0 = ttnn.to_memory_config(self.A, memory_config=self.configs["sharded_dn"])

abar1 = ttnn.mul(delta_t4, abar0, memory_config=self.configs["sharded_dn"])
ttnn.deallocate(abar0)
ttnn.deallocate(delta_t4)

abar2 = ttnn.to_memory_config(abar1, memory_config=ttnn.L1_MEMORY_CONFIG)
ttnn.deallocate(abar1) # THIS CAUSES A CRASH
abar3 = ttnn.exp(abar2, memory_config=ttnn.L1_MEMORY_CONFIG)
abar4 = ttnn.to_memory_config(abar3, memory_config=self.configs["sharded_large"])
ttnn.deallocate(abar2)

abar4 = ttnn.to_memory_config(abar3, memory_config=self.configs["sharded_dn"])
ttnn.deallocate(abar3) # THIS CAUSES A CRASH

# multiply abar and hidden_state
hidden_state0 = ttnn.to_memory_config(self.tt_hidden_state, memory_config=self.configs["sharded_large"])
amulh0 = ttnn.mul(abar4, hidden_state0, memory_config=self.configs["sharded_large"])
hidden_state0 = ttnn.to_memory_config(self.tt_hidden_state, memory_config=self.configs["sharded_dn"])
amulh0 = ttnn.mul(abar4, hidden_state0, memory_config=self.configs["sharded_dn"])

# deallocate abar and hidden_state
# ttnn.deallocate(abar1) # THIS CAUSES A CRASH
ttnn.deallocate(abar2)
# ttnn.deallocate(abar3) # THIS CAUSES A CRASH

ttnn.deallocate(abar4)
ttnn.deallocate(hidden_state0)

# B
B_proj_weights = ttnn.to_memory_config(self.B_proj_weights, memory_config=ttnn.L1_MEMORY_CONFIG)
B0 = ttnn.linear(x, B_proj_weights, memory_config=ttnn.L1_MEMORY_CONFIG)
ttnn.deallocate(B_proj_weights)
B0 = ttnn.linear(
x,
self.B_proj_weights,
memory_config=ttnn.L1_MEMORY_CONFIG,
compute_kernel_config=self.compute_kernel_config,
use_1d_systolic_array=True,
core_grid=ttnn.CoreGrid(y=self.core_grid_row, x=self.core_grid_col),
)

B1 = ttnn.repeat(B0, ttnn.Shape([1, 1, 1, self.hidden_size], [1, 1, 32, self.hidden_size]))
ttnn.deallocate(B0)
B2 = ttnn.to_memory_config(B1, memory_config=self.configs["sharded_large"])

# Shard B
B2 = ttnn.to_memory_config(B1, memory_config=self.configs["sharded_dn"])
ttnn.deallocate(B1)

# bbar
delta_t4 = ttnn.to_memory_config(delta_t3, memory_config=self.configs["sharded_large"])
# shard delta
delta_t4 = ttnn.to_memory_config(delta_t3, memory_config=self.configs["sharded_dn"])
delta_t3.deallocate()

bbar0 = ttnn.mul(delta_t4, B2, memory_config=self.configs["sharded_large"])
# bbar
bbar0 = ttnn.mul(delta_t4, B2, memory_config=self.configs["sharded_dn"])
ttnn.deallocate(delta_t4)
ttnn.deallocate(B2)

# multiply bbar and x
x0 = ttnn.repeat_interleave(x, self.n, dim=3)
x1 = ttnn.to_memory_config(x0, memory_config=self.configs["sharded_large"])
x1 = ttnn.to_memory_config(x0, memory_config=self.configs["sharded_dn"])
ttnn.deallocate(x0)
bmulx0 = ttnn.mul(bbar0, x1, memory_config=self.configs["sharded_large"])
bmulx0 = ttnn.mul(bbar0, x1, memory_config=self.configs["sharded_dn"])

# deallocate bbar
ttnn.deallocate(bbar0)
ttnn.deallocate(x1)

# add amulh and bmulx
hidden_state1 = ttnn.add(amulh0, bmulx0, memory_config=self.configs["sharded_large"])
hidden_state1 = ttnn.add(amulh0, bmulx0, memory_config=self.configs["sharded_dn"])
ttnn.deallocate(self.tt_hidden_state)
self.tt_hidden_state = ttnn.to_memory_config(hidden_state1, memory_config=ttnn.DRAM_MEMORY_CONFIG)

@@ -171,9 +204,15 @@ def forward(self, x):
ttnn.deallocate(bmulx0)

# compute C
C_proj = ttnn.to_memory_config(self.C_proj_weights, memory_config=ttnn.L1_MEMORY_CONFIG)
C0 = ttnn.linear(x, C_proj, memory_config=ttnn.L1_MEMORY_CONFIG) # b,n
ttnn.deallocate(C_proj)
C0 = ttnn.linear(
x,
self.C_proj_weights,
memory_config=ttnn.L1_MEMORY_CONFIG,
compute_kernel_config=self.compute_kernel_config,
use_1d_systolic_array=True,
core_grid=ttnn.CoreGrid(y=self.core_grid_row, x=self.core_grid_col),
) # b,n
# ttnn.deallocate(C_proj)
C1 = ttnn.permute(C0, (0, 2, 3, 1)) # b,n,1
ttnn.deallocate(C0)

@@ -188,13 +227,22 @@ def forward(self, x):
C3 = ttnn.permute(C2, (0, 3, 1, 2)) # b, d
ttnn.deallocate(C2)

# shard x
# shard C
x = ttnn.to_memory_config(x, memory_config=self.configs["sharded_d"])
C4 = ttnn.to_memory_config(C3, memory_config=self.configs["sharded_d"])
ttnn.deallocate(C3)

# shard D
D = ttnn.to_memory_config(self.D, memory_config=self.configs["sharded_d"])

# x * D
xD = ttnn.mul(x, self.D, memory_config=ttnn.L1_MEMORY_CONFIG)
xD = ttnn.mul(x, D, memory_config=self.configs["sharded_d"])
ttnn.deallocate(x)

# add xD and x
output = ttnn.add(xD, C3, memory_config=ttnn.L1_MEMORY_CONFIG)
output = ttnn.add(xD, C4, memory_config=self.configs["sharded_d"])
ttnn.deallocate(xD)
ttnn.deallocate(C3)
ttnn.deallocate(C4)

return output
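
For context, a minimal sketch of the reshard-then-multiply pattern that forward() above now uses for the wide delta/A/B intermediates; the tensor names are hypothetical, and only calls already present in this diff appear:

# Hypothetical sketch: `a` and `b` are interleaved L1 ttnn tensors of shape
# (1, 1, num_users, hidden_size * 2 * latent), and `configs` comes from
# create_model_config (next file). Reshard to "sharded_dn", multiply, then free.
a_sharded = ttnn.to_memory_config(a, memory_config=configs["sharded_dn"])
b_sharded = ttnn.to_memory_config(b, memory_config=configs["sharded_dn"])
prod = ttnn.mul(a_sharded, b_sharded, memory_config=configs["sharded_dn"])
ttnn.deallocate(a_sharded)
ttnn.deallocate(b_sharded)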
22 changes: 16 additions & 6 deletions models/experimental/mamba/tt_opt/model_config.py
@@ -14,36 +14,46 @@ def create_model_config(num_users, hidden_size):

# num_users, hidden_size*2
configs["sharded"] = ttnn.L1_MEMORY_CONFIG
'''
"""
ttnn.create_sharded_memory_config(
shape=(1, 1, num_users, hidden_size*2 // (row * col)),
core_grid=ttnn.CoreGrid(y=row, x=col),
strategy=ttnn.ShardStrategy.WIDTH,
orientation=orientation,
use_height_and_width_as_shard_shape=True,
)
'''
"""

configs["sharded_large"] = ttnn.L1_MEMORY_CONFIG
'''
"""
ttnn.create_sharded_memory_config(
shape=(1, 1, num_users, hidden_size*2 * latent // (row * col)),
core_grid=ttnn.CoreGrid(y=row, x=col),
strategy=ttnn.ShardStrategy.WIDTH,
orientation=orientation,
use_height_and_width_as_shard_shape=True,
)
'''
"""
configs["sharded_rank"] = ttnn.L1_MEMORY_CONFIG
'''
"""
ttnn.create_sharded_memory_config(
shape=(1, 1, hidden_size*2 // 32, hidden_size*2 // (row * col)),
core_grid=ttnn.CoreGrid(y=row, x=col),
strategy=ttnn.ShardStrategy.WIDTH,
orientation=orientation,
use_height_and_width_as_shard_shape=True,
)
'''
"""
configs["sharded_d"] = ttnn.create_sharded_memory_config(
shape=(1, 1, num_users, hidden_size * 2),
core_grid=ttnn.CoreGrid(y=row, x=col),
strategy=ttnn.ShardStrategy.WIDTH,
)
configs["sharded_dn"] = ttnn.create_sharded_memory_config(
shape=(1, 1, num_users, hidden_size * 2 * latent),
core_grid=ttnn.CoreGrid(y=row, x=col),
strategy=ttnn.ShardStrategy.WIDTH,
)
return configs


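A short usage sketch of the two new config entries; the sizes are hypothetical examples, and the calls mirror ones used in the SSM block diff above:

configs = create_model_config(num_users=32, hidden_size=2560)  # example sizes only

# `x` and `h` are assumed to be existing interleaved ttnn tensors on device.
# Width-shard a (1, 1, num_users, hidden_size * 2) activation across the core grid:
x_sharded = ttnn.to_memory_config(x, memory_config=configs["sharded_d"])

# Width-shard a (1, 1, num_users, hidden_size * 2 * latent) intermediate:
h_sharded = ttnn.to_memory_config(h, memory_config=configs["sharded_dn"])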
