
Commit

#7012: Add sharding support to Mamba block
kpaigwar authored and esmalTT committed Apr 2, 2024
1 parent d61fc22 commit 2a8eac4
Showing 5 changed files with 97 additions and 46 deletions.
models/experimental/mamba/tt_opt/full_model.py (24 changes: 15 additions & 9 deletions)
@@ -3,9 +3,10 @@
# SPDX-License-Identifier: Apache-2.0

import torch

import ttnn
import os

from loguru import logger

from pathlib import Path
from typing import Callable

@@ -57,7 +58,7 @@ def load_tt_tensor(


class MambaTT(torch.nn.Module):
def __init__(self, reference_model, device: ttnn.device, configs, tt_cache_path: str = "", num_layers=None):
def __init__(self, reference_model, device: ttnn.Device, configs, tt_cache_path: str = "", num_layers=None):
super().__init__()
self.args = reference_model.args
self.device = device
@@ -67,7 +68,7 @@ def __init__(self, reference_model, device: ttnn.device, configs, tt_cache_path:
self.num_layers = len(reference_model.layers)
else:
self.num_layers = num_layers
print(f"Initalizing MambaTT with {self.num_layers} layers")
logger.info(f"Initalizing MambaTT with {self.num_layers} layers")

self.embedding = reference_model.embedding

@@ -82,8 +83,13 @@ def __init__(self, reference_model, device: ttnn.device, configs, tt_cache_path:
self.lm_head = reference_model.lm_head

def forward(self, x):
x = self.embedding(x)
x = x.squeeze(1)
assert len(x.shape) == 2, f"Mamba expects inputs to be rank 2 (was {len(x.shape)})"

x = self.embedding(x) # (B, 1, E)
x = x.squeeze(1).unsqueeze(0).unsqueeze(0) # (1, 1, B, E)

assert len(x.shape) == 4, f"Expected embedding to be rank 4 (was {len(x.shape)})"

x = ttnn.from_torch(
x,
device=self.device,
@@ -94,9 +100,9 @@ def forward(self, x):
for layer in self.layers:
x = layer(x)

x = ttnn.to_torch(x).to(torch.float32)
x = x.unsqueeze(1)
x = self.norm_f(x)
x = ttnn.to_torch(x).to(torch.float32) # (1, 1, B, E)
x = x.squeeze(0).squeeze(0).unsqueeze(1)
x = self.norm_f(x) # (B, 1, E) -> (B, 1, E)
x = self.lm_head(x)

return x
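
Note: forward() now moves through three layouts: rank-2 token ids in, a (1, 1, B, E) activation for the TT layers, and (B, 1, E) back out for norm_f and lm_head. A minimal PyTorch sketch of just that reshaping (the embedding layer, vocab size, and d_model below are illustrative stand-ins, not values taken from this commit):

import torch

B, vocab, E = 32, 50277, 2560                  # illustrative sizes only
embedding = torch.nn.Embedding(vocab, E)       # stand-in for reference_model.embedding

tokens = torch.randint(0, vocab, (B, 1))       # rank-2 input: one token per user
x = embedding(tokens)                          # (B, 1, E)
x = x.squeeze(1).unsqueeze(0).unsqueeze(0)     # (1, 1, B, E) layout fed to the TT layers
assert x.shape == (1, 1, B, E)

# ... the ttnn layers keep the (1, 1, B, E) layout ...

x = x.squeeze(0).squeeze(0).unsqueeze(1)       # back to (B, 1, E) for norm_f and lm_head
assert x.shape == (B, 1, E)
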
models/experimental/mamba/tt_opt/mamba_block.py (96 changes: 71 additions & 25 deletions)
@@ -5,36 +5,37 @@
import torch

import ttnn
import tt_lib as ttl
from typing import Callable

from models.utility_functions import torch2tt_tensor, tt2torch_tensor
from models.helper_funcs import Linear
from models.experimental.mamba.reference.args import ModelArgs
from models.experimental.mamba.tt_opt.mamba_one_step_ssm import TtMambaSSM


class TtMambaBlock(torch.nn.Module):
def __init__(
self,
args: ModelArgs,
device,
configs,
load_fn: Callable
):
def __init__(self, args: ModelArgs, device, configs, load_fn: Callable):
super().__init__()

self.device = device
self.args = args
self.num_users = args.batch_size
self.configs = configs

assert self.num_users == 32, "Batch size must be 32 for now"

in_proj_weight_name = "mixer.in_proj.weight"

# ssm wt
self.ssm_in_proj_weights = load_fn(in_proj_weight_name, lambda x: x[: self.args.d_inner, :].transpose(-1, -2), postfix="ssm")
self.ssm_in_proj_weights = load_fn(
in_proj_weight_name, lambda x: x[: self.args.d_inner, :].transpose(-1, -2), postfix="ssm"
)

# mlp wt
self.mlp_proj_weights = load_fn(in_proj_weight_name, lambda x: x[self.args.d_inner :, :].transpose(-1, -2), postfix="mlp")
self.mlp_proj_weights = load_fn(
in_proj_weight_name, lambda x: x[self.args.d_inner :, :].transpose(-1, -2), postfix="mlp"
)

# down proj wt
out_proj_weight_name = "mixer.out_proj.weight"
@@ -48,10 +49,7 @@ def __init__(
self.conv1d_weights.append(
load_fn(
conv1d_weight_name,
lambda x: x[:, :, i]
.transpose(-1, -2)
.repeat(self.num_users, 1)
.unsqueeze(0).unsqueeze(0),
lambda x: x[:, :, i].transpose(-1, -2).repeat(self.num_users, 1).unsqueeze(0).unsqueeze(0),
postfix=f"{i}_{args.batch_size}",
)
)
@@ -73,42 +71,90 @@ def __init__(
)
)

self.tt_ssm = TtMambaSSM(self.args,self.device, configs, load_fn)
self.tt_ssm = TtMambaSSM(self.args, self.device, configs, load_fn)

self.compute_kernel_config = ttl.tensor.WormholeComputeKernelConfig(
math_fidelity=ttl.tensor.MathFidelity.HiFi3,
math_approx_mode=False,
fp32_dest_acc_en=True,
)
self.core_grid_row = 4
self.core_grid_col = 8

def forward(self, x):
assert len(x.shape) == 4, "Mamba block expects inputs to be rank 4"

residual_connection = x # b, e=d_model
residual_connection = x # b, e=d_model

x = ttnn.linear(x, self.ssm_in_proj_weights, memory_config=ttnn.L1_MEMORY_CONFIG)
x = ttnn.linear(
x,
self.ssm_in_proj_weights,
memory_config=ttnn.L1_MEMORY_CONFIG,
compute_kernel_config=self.compute_kernel_config,
use_1d_systolic_array=True,
core_grid=ttnn.CoreGrid(y=4, x=8),
)

# shift the states leftward
ttnn.deallocate(self.conv_states[0])
for i in range(3):
self.conv_states[i] = self.conv_states[i + 1]

# update the last state and move it back to DRAM with all the other states
self.conv_states[3] = ttnn.to_memory_config(x, memory_config=ttnn.DRAM_MEMORY_CONFIG)

x = ttnn.mul(self.conv1d_weights[0], self.conv_states[0], memory_config=ttnn.L1_MEMORY_CONFIG)

for i in range(1,4):
prod = ttnn.mul(self.conv1d_weights[i], self.conv_states[i], memory_config=ttnn.L1_MEMORY_CONFIG)
x = ttnn.add(x, prod, memory_config=ttnn.L1_MEMORY_CONFIG)
# do the convolution
conv1d_wt = ttnn.to_memory_config(self.conv1d_weights[0], memory_config=self.configs["sharded_d"])
conv_state = ttnn.to_memory_config(self.conv_states[0], memory_config=self.configs["sharded_d"])
x = ttnn.mul(conv_state, conv1d_wt, memory_config=self.configs["sharded_d"])
ttnn.deallocate(conv1d_wt)
ttnn.deallocate(conv_state)

for i in range(1, 4):
conv1d_wt = ttnn.to_memory_config(self.conv1d_weights[i], memory_config=self.configs["sharded_d"])
conv_state = ttnn.to_memory_config(self.conv_states[i], memory_config=self.configs["sharded_d"])
prod = ttnn.mul(conv_state, conv1d_wt, memory_config=self.configs["sharded_d"])
ttnn.deallocate(conv1d_wt)
ttnn.deallocate(conv_state)

x = ttnn.add(x, prod, memory_config=self.configs["sharded_d"])
ttnn.deallocate(prod)

x = ttnn.add(x, self.conv1d_bias, memory_config=ttnn.L1_MEMORY_CONFIG)
conv1d_bias = ttnn.to_memory_config(self.conv1d_bias, memory_config=self.configs["sharded_d"])
x = ttnn.add(x, conv1d_bias, memory_config=self.configs["sharded_d"])
ttnn.deallocate(conv1d_bias)

x = ttnn.to_memory_config(x, memory_config=ttnn.L1_MEMORY_CONFIG)
x = ttnn.silu(x, memory_config=ttnn.L1_MEMORY_CONFIG)

x = self.tt_ssm(x)

residual = ttnn.linear(residual_connection, self.mlp_proj_weights, memory_config=ttnn.L1_MEMORY_CONFIG)
residual = ttnn.linear(
residual_connection,
self.mlp_proj_weights,
memory_config=ttnn.L1_MEMORY_CONFIG,
core_grid=ttnn.CoreGrid(y=4, x=8),
compute_kernel_config=self.compute_kernel_config,
use_1d_systolic_array=True,
)
ttnn.deallocate(residual_connection)

residual_with_silu = ttnn.silu(residual, memory_config=ttnn.L1_MEMORY_CONFIG)
ttnn.deallocate(residual)

out = ttnn.mul(x, residual_with_silu, memory_config=ttnn.L1_MEMORY_CONFIG)
residual_with_silu = ttnn.to_memory_config(residual_with_silu, memory_config=self.configs["sharded_d"])
out = ttnn.mul(x, residual_with_silu, memory_config=self.configs["sharded_d"])
ttnn.deallocate(residual_with_silu)
ttnn.deallocate(x)

out = ttnn.linear(out, self.out_proj_weights, memory_config=ttnn.L1_MEMORY_CONFIG)
out = ttnn.to_memory_config(out, memory_config=ttnn.L1_MEMORY_CONFIG)
out = ttnn.linear(
out,
self.out_proj_weights,
memory_config=ttnn.L1_MEMORY_CONFIG,
core_grid=ttnn.CoreGrid(y=4, x=8),
compute_kernel_config=self.compute_kernel_config,
use_1d_systolic_array=True,
)

return out
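
Note: configs["sharded_d"] comes from model_config.create_model_config, which is not part of this diff. The forward pass above stages each DRAM-resident conv weight, conv state, and bias into that sharded L1 config right before it is used and deallocates it afterwards, so only one sharded copy is live at a time. A hedged sketch of how such a width-sharded config could be built with the ttnn API (the shard shape, core grid, and key name are assumptions, chosen to match the 4x8 grid and the (1, 1, B, d_inner) activations used above):

import ttnn

def create_model_config(num_users: int, d_inner: int) -> dict:
    # Assumed sketch only: width-shard the (1, 1, num_users, d_inner) activations so
    # each core of the 4x8 grid keeps a slice of the hidden dimension in its local L1.
    configs = {}
    configs["sharded_d"] = ttnn.create_sharded_memory_config(
        shape=(1, 1, num_users, d_inner),
        core_grid=ttnn.CoreGrid(y=4, x=8),
        strategy=ttnn.ShardStrategy.WIDTH,
    )
    return configs

Width sharding fits this block because the elementwise conv and gating ops act independently per hidden channel, so splitting d_inner across cores avoids any cross-core reduction.
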
models/experimental/mamba/tt_opt/mamba_one_step_ssm.py (4 changes: 3 additions & 1 deletion)
@@ -97,14 +97,16 @@ def preprocess_A(x):
self.tt_hidden_state = load_fn(f"tt_hidden_state_{args.batch_size}", torch_tensor=prev_hidden_states)

self.compute_kernel_config = ttl.tensor.WormholeComputeKernelConfig(
math_fidelity=ttl.tensor.MathFidelity.LoFi,
math_fidelity=ttl.tensor.MathFidelity.HiFi3,
math_approx_mode=False,
fp32_dest_acc_en=True,
)
self.core_grid_row = 4
self.core_grid_col = 8

def forward(self, x):
assert len(x.shape) == 4, "SSM block expects inputs to be rank 4"

# delta
delta_t0 = ttnn.linear(
x,
models/experimental/mamba/tt_opt/residual_block.py (15 changes: 6 additions & 9 deletions)
@@ -11,14 +11,9 @@
from models.experimental.mamba.reference.args import ModelArgs
from models.experimental.mamba.tt_opt.mamba_block import TtMambaBlock


class TtResidualBlock(torch.nn.Module):
def __init__(
self,
args: ModelArgs,
device,
configs,
load_fn: Callable
):
def __init__(self, args: ModelArgs, device, configs, load_fn: Callable):
super().__init__()

self.device = device
@@ -30,10 +25,12 @@ def __init__(self,
self.tt_mamba_block = TtMambaBlock(self.args, self.device, configs, load_fn)

def forward(self, x):
assert len(x.shape) == 4, "Mamba residual block expects inputs to be rank 4"

mamba_input = x
rms_norm_weights = ttnn.to_memory_config(self.rms_norm_weights, memory_config=ttnn.L1_MEMORY_CONFIG)
mamba_input = ttnn.rms_norm(x, rms_norm_weights, epsilon=self.args.eps)
ttnn.deallocate(rms_norm_weights)

mamba_input = self.tt_mamba_block(mamba_input)
x = ttnn.add(x, mamba_input)
return x
return ttnn.add(x, mamba_input)
models/experimental/mamba/tt_opt/test_driver.py (4 changes: 2 additions & 2 deletions)
@@ -37,7 +37,7 @@ def get_tt_metal_model(num_users, hidden_size, configs, version):

def run_demo(num_users, hidden_size, profile):
configs = model_config.create_model_config(num_users, hidden_size)
model, device = get_tt_metal_model(num_users, hidden_size, configs, 'mamba-2.8b-slimpj')
model, device = get_tt_metal_model(num_users, hidden_size, configs, "mamba-2.8b-slimpj")

# evaluate model:
model.eval()
@@ -54,7 +54,7 @@ def run_demo(num_users, hidden_size, profile):
out_data = model(input_data)

ttnn.close_device(device)

return out_data
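
A hedged usage sketch of the updated driver (the argument values are assumptions: TtMambaBlock asserts a batch size of 32, and 2560 is the commonly quoted d_model for mamba-2.8b; only run_demo's signature is taken from the diff above):

if __name__ == "__main__":
    # Assumed invocation; run_demo opens and closes the device itself, so this needs
    # actual TT hardware to execute.
    logits = run_demo(num_users=32, hidden_size=2560, profile=False)
    print(logits.shape)
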


