Commit acc4292
#3812: Use tilize operators for mistral model

vigneshkeerthivasanx committed Nov 24, 2023
1 parent 8106601 commit acc4292
Showing 11 changed files with 199 additions and 18 deletions.
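For context, the diff below routes activations through the TILE layout before device ops and strips the padding afterwards. A minimal sketch of the padding that tilizing implies, assuming the usual 32x32 tile size of the TILE layout; the helper name here is illustrative and not from the commit.

# Editorial sketch, not part of the commit: only the last two dims are rounded
# up, mirroring the pad_to_tile_shape(shape, False, False, True, True) calls below.
def pad_last_two_dims_to_tiles(shape, tile=32):
    padded = list(shape)
    padded[-2] = ((padded[-2] + tile - 1) // tile) * tile  # rows up to a multiple of 32
    padded[-1] = ((padded[-1] + tile - 1) // tile) * tile  # cols up to a multiple of 32
    return padded

# The tests feed 1 x 11 x 4096 activations (4D on device: 1 x 1 x 11 x 4096);
# tilizing pads the sequence dim from 11 to 32, which unpad_from_zero later strips.
assert pad_last_two_dims_to_tiles([1, 1, 11, 4096]) == [1, 1, 32, 4096]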
75 changes: 74 additions & 1 deletion models/experimental/mistral/mistral_helper_funcs.py
@@ -3,7 +3,8 @@
# SPDX-License-Identifier: Apache-2.0
import tt_lib
from typing import Optional
from models.utility_functions import tt_to_torch_tensor, torch_to_tt_tensor_rm
from models.utility_functions import tt_to_torch_tensor, torch_to_tt_tensor_rm, tt2torch_tensor
import torch


def Linear(
@@ -45,3 +46,75 @@ def linear_(activation):
return output

return linear_


def format_tensor(x, target_layout, device, output_mem_config, pad_value=0.0):
if x.layout() == target_layout:
return x
if x.layout() == tt_lib.tensor.Layout.ROW_MAJOR and target_layout == tt_lib.tensor.Layout.TILE:
x_padded_shape = tt_lib.tensor.pad_to_tile_shape(x.shape(), False, False, True, True)
if x.shape() != x_padded_shape:
return tt_lib.tensor.format_input_tensor(
x, device, x_padded_shape, pad_value, target_layout, output_mem_config
)
else:
return tt_lib.tensor.tilize(x, output_mem_config, use_multicore=True)
elif x.layout() == tt_lib.tensor.Layout.TILE and target_layout == tt_lib.tensor.Layout.ROW_MAJOR:
if x.shape() != x.shape_without_padding():
return tt_lib.tensor.format_output_tensor(
x, x.shape_without_padding(), device, target_layout, output_mem_config
)
else:
return tt_lib.tensor.untilize(x, output_mem_config, use_multicore=True)
else:
assert False


def unpad_from_zero(x, desired_shape):
if x.shape()[-1] == desired_shape[-1] and x.shape()[-2] == desired_shape[-2]:
x = tt2torch_tensor(x)
else:
x = x.cpu()
if x.layout() != tt_lib.tensor.Layout.ROW_MAJOR:
x = x.to(tt_lib.tensor.Layout.ROW_MAJOR)
x = x.unpad(
(0, 0, 0, 0), (desired_shape[0] - 1, desired_shape[1] - 1, desired_shape[2] - 1, desired_shape[3] - 1)
)
x = x.to_torch().to(torch.float)
return x
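A hypothetical usage sketch of the two helpers above (not part of the commit): save the logical shape, tilize for a device op, then untilize and strip the padding. `device` and `mem_config` are assumed to be a device handle and a DRAM-interleaved MemoryConfig as in the tests further down.

import tt_lib
from models.utility_functions import torch_to_tt_tensor_rm
from models.experimental.mistral.mistral_helper_funcs import format_tensor, unpad_from_zero

def run_op_tilized(torch_input, device, mem_config):
    x = torch_to_tt_tensor_rm(torch_input, device, put_on_device=True)   # ROW_MAJOR tensor on device
    desired_shape = x.shape().copy()                                      # remember the logical shape
    x = format_tensor(x, tt_lib.tensor.Layout.TILE, device, mem_config)   # pad to tile shape + tilize
    x = tt_lib.tensor.silu(x)                                             # any tile-layout device op
    return unpad_from_zero(x, desired_shape)                              # untilize + unpad, back to torch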


4 changes: 4 additions & 0 deletions models/experimental/mistral/tests/test_mistral_attention.py
@@ -68,12 +68,16 @@ def test_mistral_attention_inference(
model_args.FALLBACK_EMPTY = empty_ondevice
model_args.FALLBACK_SCATTER = scatter_ondevice
model_args.WEIGHTS_DTYPE = dtype
output_mem_config = tt_lib.tensor.MemoryConfig(
tt_lib.tensor.TensorMemoryLayout.INTERLEAVED, tt_lib.tensor.BufferType.DRAM
)
tt_cache_path = "/mnt/MLPerf/tt_dnn-models/tt/Mistral/"
tt_model = TtAttention(
args=model_args,
device=device,
base_address=base_address,
tt_cache_path=tt_cache_path,
output_mem_config=output_mem_config,
)
input = torch.randn(1, 11, 4096)
empty_tensor = torch.zeros((11, 64))
@@ -36,14 +36,17 @@ def test_mistral_feed_forward_inference(pcc, model_location_generator, device, d
model_args.WEIGHTS_DTYPE = dtype
reference_model = FeedForward(args=model_args)
reference_model.load_state_dict(state_dict)

output_mem_config = tt_lib.tensor.MemoryConfig(
tt_lib.tensor.TensorMemoryLayout.INTERLEAVED, tt_lib.tensor.BufferType.DRAM
)
tt_cache_path = "/mnt/MLPerf/tt_dnn-models/tt/Mistral/"

tt_model = TtFeedForward(
args=model_args,
device=device,
base_address=base_address,
tt_cache_path=tt_cache_path,
output_mem_config=output_mem_config,
)
input = torch.rand(1, 11, 4096)
reference_output = reference_model(input)
6 changes: 5 additions & 1 deletion models/experimental/mistral/tests/test_mistral_rms_norm.py
@@ -33,12 +33,16 @@ def test_mistral_rms_norm_inference(pcc, model_location_generator, device, reset
model_args.max_batch_size = 1
reference_model = RMSNorm(dim=dim)
reference_model.load_state_dict(state_dict)

output_mem_config = tt_lib.tensor.MemoryConfig(
tt_lib.tensor.TensorMemoryLayout.INTERLEAVED, tt_lib.tensor.BufferType.DRAM
)
tt_cache_path = "/mnt/MLPerf/tt_dnn-models/tt/Mistral/"
tt_model = TtRMSNorm(
dim=dim,
base_address=base_address,
tt_cache_path=tt_cache_path,
output_mem_config=output_mem_config,
device=device,
)
input = torch.rand(1, 11, 4096)
reference_output = reference_model(input)
@@ -42,7 +42,7 @@ def test_mistral_transformer_inference(pcc, model_location_generator, device, dt
model_args.WEIGHTS_DTYPE = dtype
model_args.max_batch_size = 1
model_args.n_layers = 32
model_args.WEIGHTS_DTYPE = dtype

reference_model = Transformer(args=model_args)
reference_model.load_state_dict(state_dict)

@@ -38,14 +38,17 @@ def test_mistral_transformer_block_inference(pcc, model_location_generator, devi
model_args.WEIGHTS_DTYPE = dtype
reference_model = TransformerBlock(args=model_args)
reference_model.load_state_dict(state_dict)

output_mem_config = tt_lib.tensor.MemoryConfig(
tt_lib.tensor.TensorMemoryLayout.INTERLEAVED, tt_lib.tensor.BufferType.DRAM
)
tt_cache_path = "/mnt/MLPerf/tt_dnn-models/tt/Mistral/"

tt_model = TtTransformerBlock(
args=model_args,
device=device,
base_address=base_address,
tt_cache_path=tt_cache_path,
output_mem_config=output_mem_config,
)

input = torch.randn(1, 11, 4096)
54 changes: 49 additions & 5 deletions models/experimental/mistral/tt/mistral_attention.py
@@ -10,7 +10,7 @@
from tt_lib.fused_ops.softmax import softmax as Ttsoftmax
from models.experimental.mistral.tt.mistral_configuration import TtModelArgs
from models.utility_functions import torch_to_tt_tensor_rm, tt_to_torch_tensor
from models.experimental.mistral.mistral_helper_funcs import Linear as TtLinear
from models.experimental.mistral.mistral_helper_funcs import Linear as TtLinear, format_tensor, unpad_from_zero


class TtAttention(nn.Module):
@@ -20,12 +20,13 @@ def __init__(
base_address=None,
device=None,
tt_cache_path=None,
output_mem_config=None,
):
super().__init__()
self.args = args
self.device = device
self.base_address = base_address

self.output_mem_config = output_mem_config
self.n_heads: int = args.n_heads
self.n_kv_heads: int = args.n_kv_heads

@@ -118,8 +119,28 @@ def forward(
mask: Optional[torch.Tensor],
) -> tt_lib.tensor.Tensor:
_, bsz, seqlen, _ = x.shape()
x_desired_shape = x.shape()
x = format_tensor(x, tt_lib.tensor.Layout.TILE, self.device, self.output_mem_config)

xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)

xq_desired_shape = x_desired_shape.copy()
xq_desired_shape[-1] = self.wq_weights.shape()[-2]

xk_desired_shape = x_desired_shape.copy()
xk_desired_shape[-1] = self.wk_weights.shape()[-2]

xv_desired_shape = x_desired_shape.copy()
xv_desired_shape[-1] = self.wv_weights.shape()[-2]

xq = unpad_from_zero(xq, xq_desired_shape)
xk = unpad_from_zero(xk, xk_desired_shape)
xv = unpad_from_zero(xv, xv_desired_shape)

xq = torch_to_tt_tensor_rm(xq, self.device, put_on_device=True)
xk = torch_to_tt_tensor_rm(xk, self.device, put_on_device=True)
xv = torch_to_tt_tensor_rm(xv, self.device, put_on_device=True)

xq = tt_lib.tensor.reshape(xq, bsz, seqlen, self.n_heads, self.args.head_dim)

xk = tt_lib.tensor.reshape(xk, bsz, seqlen, self.n_kv_heads, self.args.head_dim)
@@ -167,23 +188,39 @@ def forward(
)

xq = torch_to_tt_tensor_rm(xq, self.device)
xq = format_tensor(xq, tt_lib.tensor.Layout.TILE, self.device, self.output_mem_config)
query = tt_lib.tensor.transpose(xq, 1, -2, output_mem_config=self.args.out_mem_config)
desired_score_shape = query.shape().copy()
desired_score_shape[-1] = key.shape()[1]

xq.deallocate()

key = format_tensor(key, tt_lib.tensor.Layout.TILE, self.device, self.output_mem_config)
value = format_tensor(value, tt_lib.tensor.Layout.TILE, self.device, self.output_mem_config)

key = tt_lib.tensor.transpose(key, 1, -2, output_mem_config=self.args.out_mem_config)
value = tt_lib.tensor.transpose(value, 1, -2, output_mem_config=self.args.out_mem_config)
key = format_tensor(key, tt_lib.tensor.Layout.TILE, self.device, self.output_mem_config)
value = format_tensor(value, tt_lib.tensor.Layout.TILE, self.device, self.output_mem_config)

key = tt_lib.tensor.transpose(key, -2, -1, output_mem_config=self.args.out_mem_config)
key = format_tensor(key, tt_lib.tensor.Layout.TILE, self.device, self.output_mem_config)
query = format_tensor(query, tt_lib.tensor.Layout.TILE, self.device, self.output_mem_config)

scores = tt_lib.tensor.bmm(query, key, output_mem_config=self.args.out_mem_config)
key.deallocate()
scores = tt_lib.tensor.mul_unary(scores, self.scale, output_mem_config=self.args.out_mem_config)
scores = tt_to_torch_tensor(scores)
scores = unpad_from_zero(scores, desired_score_shape)

if mask is not None:
if mask.dim() == 4:
mask = mask.squeeze()
scores += mask[None, None, ...]

scores = torch_to_tt_tensor_rm(scores, self.device, put_on_device=False)

desired_output_shape = scores.shape().copy()
desired_output_shape[-1] = value.shape()[-1]
scores = format_tensor(scores, tt_lib.tensor.Layout.TILE, self.device, self.output_mem_config, pad_value=-10000)
if self.args.FALLBACK_SOFTMAX:
scores = fallback_ops.softmax(scores, dim=-1)
else:
@@ -195,10 +232,17 @@ def forward(

value.deallocate()
scores.deallocate()
output = unpad_from_zero(output, desired_output_shape)
output = torch_to_tt_tensor_rm(output, self.device, put_on_device=False)
output = tt_lib.tensor.transpose(output, 1, -2, output_mem_config=self.args.out_mem_config)

output = fallback_ops.reshape(output, 1, bsz, seqlen, -1)
return self.wo(output)
desired_output_shape = output.shape().copy()
output = format_tensor(output, tt_lib.tensor.Layout.TILE, self.device, self.output_mem_config)
output = self.wo(output)
output = unpad_from_zero(output, desired_output_shape)
output = torch_to_tt_tensor_rm(output, self.device, put_on_device=False)
return output


def _reshape_for_broadcast(freqs_cis: torch.Tensor, x_shape, x_ndim) -> torch.Tensor:
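A side note on the pad_value=-10000 used before the softmax in TtAttention.forward above: tile padding adds extra score columns, and filling them with a large negative value drives their softmax weight to effectively zero, so padded positions receive no attention. A plain PyTorch illustration (editorial, not from the commit):

import torch

scores = torch.tensor([[2.0, 1.0, -10000.0, -10000.0]])   # 2 real key columns, 2 padded
probs = torch.softmax(scores, dim=-1)
assert probs[0, 2] < 1e-6 and probs[0, 3] < 1e-6            # padded columns get ~0 weight
assert abs(probs[0, 0] + probs[0, 1] - 1.0) < 1e-6          # real columns carry all the mass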
11 changes: 9 additions & 2 deletions models/experimental/mistral/tt/mistral_feed_forward.py
@@ -6,7 +6,7 @@
import tt_lib
from models.experimental.mistral.tt.mistral_configuration import TtModelArgs
from models.utility_functions import torch_to_tt_tensor_rm, tt_to_torch_tensor
from models.experimental.mistral.mistral_helper_funcs import Linear as TtLinear
from models.experimental.mistral.mistral_helper_funcs import Linear as TtLinear, format_tensor, unpad_from_zero


class TtFeedForward(nn.Module):
@@ -16,11 +16,13 @@ def __init__(
base_address=None,
device=None,
tt_cache_path=None,
output_mem_config=None,
):
super().__init__()
self.device = device
self.args = args

self.output_mem_config = output_mem_config
self.w1_weights = tt_lib.tensor.load_tensor(
tt_cache_path + base_address + "w1.weight" + str(self.args.WEIGHTS_DTYPE) + ".bin"
)
@@ -52,6 +54,11 @@ def __init__(
)

def forward(self, x: tt_lib.tensor.Tensor) -> tt_lib.tensor.Tensor:
desired_shape = x.shape().copy()
x = format_tensor(x, tt_lib.tensor.Layout.TILE, self.device, self.output_mem_config)
silu_out = tt_lib.tensor.silu(self.w1(x))
x = tt_lib.tensor.mul(silu_out, self.w3(x))
return self.w2(x)
out = self.w2(x)
out = unpad_from_zero(out, desired_shape)
out = torch_to_tt_tensor_rm(out, self.device, put_on_device=False)
return out
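For reference, a plain PyTorch sketch of the gated feed-forward that TtFeedForward.forward computes above, w2(silu(w1(x)) * w3(x)). The torch.nn.functional.linear weight convention here is an assumption for illustration only.

import torch
import torch.nn.functional as F

def feed_forward_reference(x: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, w3: torch.Tensor) -> torch.Tensor:
    # SiLU-gated MLP: project with w1 and w3, gate, then project back with w2
    return F.linear(F.silu(F.linear(x, w1)) * F.linear(x, w3), w2)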
15 changes: 12 additions & 3 deletions models/experimental/mistral/tt/mistral_rms_norm.py
@@ -3,22 +3,31 @@
# SPDX-License-Identifier: Apache-2.0
import torch.nn as nn
import tt_lib
from models.utility_functions import tt_to_torch_tensor, torch_to_tt_tensor_rm
from models.utility_functions import torch_to_tt_tensor_rm
from models.experimental.mistral.mistral_helper_funcs import format_tensor, unpad_from_zero


class TtRMSNorm(nn.Module):
def __init__(
self,
dim: int,
eps: float = 1e-6,
device=None,
base_address=None,
tt_cache_path=None,
output_mem_config=None,
):
super().__init__()
self.eps = eps

self.device = device
self.output_mem_config = output_mem_config
# bfp8 reduces PCC, so the weights are kept in bfloat16
self.weight = tt_lib.tensor.load_tensor(tt_cache_path + base_address + "weightDataType.BFLOAT16.bin")

def forward(self, x: tt_lib.tensor.Tensor) -> tt_lib.tensor.Tensor:
return tt_lib.tensor.rmsnorm(x, self.eps, self.weight)
desired_shape = x.shape().copy()
x = format_tensor(x, tt_lib.tensor.Layout.TILE, self.device, self.output_mem_config)
x = tt_lib.tensor.rmsnorm(x, self.eps, self.weight)
x = unpad_from_zero(x, desired_shape)
x = torch_to_tt_tensor_rm(x, self.device, put_on_device=False)
return x
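For reference, a plain PyTorch sketch of the RMSNorm that tt_lib.tensor.rmsnorm is expected to compute here (eps and the learned weight mirror the module above); this is an editorial illustration, not code from the commit.

import torch

def rms_norm_reference(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # scale by the reciprocal root-mean-square over the last dim, then by the learned weight
    return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps) * weight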
7 changes: 6 additions & 1 deletion models/experimental/mistral/tt/mistral_transformer.py
@@ -40,14 +40,17 @@ def __init__(

embedding_weights = state_dict["tok_embeddings.weight"]
self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim, _weight=embedding_weights)

self.output_mem_config = tt_lib.tensor.MemoryConfig(
tt_lib.tensor.TensorMemoryLayout.INTERLEAVED, tt_lib.tensor.BufferType.DRAM
)
self.layers = torch.nn.ModuleList(
[
TtTransformerBlock(
args=args,
base_address=f"layers.{i}.",
device=self.device,
tt_cache_path=tt_cache_path,
output_mem_config=self.output_mem_config,
)
for i in range(args.n_layers)
]
@@ -57,6 +60,8 @@
base_address=f"norm.",
eps=args.norm_eps,
tt_cache_path=tt_cache_path,
device=self.device,
output_mem_config=self.output_mem_config,
)

self.output_weight = tt_lib.tensor.load_tensor(