Skip to content

Commit

Permalink
#3812: Use tilize operators for mistral model
Browse files Browse the repository at this point in the history
  • Loading branch information
vigneshkeerthivasanx committed Dec 15, 2023
1 parent bb55199 commit 92b3236
Show file tree
Hide file tree
Showing 12 changed files with 287 additions and 174 deletions.
104 changes: 103 additions & 1 deletion models/experimental/mistral/mistral_helper_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
# SPDX-License-Identifier: Apache-2.0
import tt_lib
from typing import Optional
from models.utility_functions import tt_to_torch_tensor, torch_to_tt_tensor_rm
from models.utility_functions import tt_to_torch_tensor, torch_to_tt_tensor_rm, tt2torch_tensor
import torch


def Linear(
Expand Down Expand Up @@ -45,3 +46,104 @@ def linear_(activation):
return output

return linear_


def format_tensor(x, target_layout, device, output_mem_config, pad_value=0.0):
if x.layout() == target_layout:
return x
if x.layout() == tt_lib.tensor.Layout.ROW_MAJOR and target_layout == tt_lib.tensor.Layout.TILE:
x_padded_shape = tt_lib.tensor.pad_to_tile_shape(x.shape(), False, False, True, True)
if x.shape() != x_padded_shape:
return tt_lib.tensor.format_input_tensor(
x, device, x_padded_shape, pad_value, target_layout, output_mem_config
)
else:
return tt_lib.tensor.tilize(x, output_mem_config, use_multicore=True)
elif x.layout() == tt_lib.tensor.Layout.TILE and target_layout == tt_lib.tensor.Layout.ROW_MAJOR:
if x.shape() != x.shape_without_padding():
return tt_lib.tensor.format_output_tensor(
x, x.shape_without_padding(), device, target_layout, output_mem_config
)
else:
return tt_lib.tensor.untilize(x, output_mem_config, use_multicore=True)
else:
assert False


def unpad_from_zero(x, desired_shape):
if x.shape()[-1] == desired_shape[-1] and x.shape()[-2] == desired_shape[-2]:
x = tt2torch_tensor(x)
else:
x = x.cpu()
if x.layout() != tt_lib.tensor.Layout.ROW_MAJOR:
x = x.to(tt_lib.tensor.Layout.ROW_MAJOR)
x = x.unpad(
(0, 0, 0, 0), (desired_shape[0] - 1, desired_shape[1] - 1, desired_shape[2] - 1, desired_shape[3] - 1)
)
x = x.to_torch().to(torch.float)
return x


def _reshape_for_broadcast(freqs_cis: torch.Tensor, x_shape, x_ndim) -> torch.Tensor:
"""
freqs_cis: complex - (seq_len, head_dim / 2)
x: complex - (bsz, seq_len, head_dim / 2)
"""
ndim = x_ndim
assert 1 < ndim
assert freqs_cis.shape == (x_shape[1], x_shape[-1]), (
freqs_cis.shape,
(x_shape[1], x_shape[-1]),
)
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x_shape)]
return freqs_cis.view(*shape)


def get_freqs_cis(freqs_cis: torch.Tensor, query_shape, key_shape, device=None, mem_config=None):
freqs_cis = _reshape_for_broadcast(freqs_cis, query_shape, 4)

freq_real = torch_to_tt_tensor_rm(freqs_cis.real, device)
freq_img = torch_to_tt_tensor_rm(freqs_cis.imag, device)
freqs_cis = tt_lib.tensor.complex_tensor(freq_real, freq_img)

freq_real.deallocate()
freq_img.deallocate()

BCH = tt_lib.tensor.BcastOpDim.HW
BCMUL = tt_lib.tensor.BcastOpMath.MUL

t_one_xq = tt_lib.tensor.ones(query_shape, output_mem_config=mem_config)
t_one_xq = tt_lib.tensor.permute(t_one_xq, (3, 1, 2, 0), output_mem_config=mem_config)

freqs_real = tt_lib.tensor.permute(freqs_cis.real, (3, 1, 2, 0), output_mem_config=mem_config)
freqs_imag = tt_lib.tensor.permute(freqs_cis.imag, (3, 1, 2, 0), output_mem_config=mem_config)

bcast_freq_re_xq = tt_lib.tensor.bcast(t_one_xq, freqs_real, BCMUL, BCH, output_mem_config=mem_config)
bcast_freq_im_xq = tt_lib.tensor.bcast(t_one_xq, freqs_imag, BCMUL, BCH, output_mem_config=mem_config)
bcast_freq_re_xq = tt_lib.tensor.permute(bcast_freq_re_xq, (3, 1, 2, 0), output_mem_config=mem_config)
bcast_freq_im_xq = tt_lib.tensor.permute(bcast_freq_im_xq, (3, 1, 2, 0), output_mem_config=mem_config)
t_one_xq.deallocate()

bcast_freq_xq = tt_lib.tensor.complex_tensor(bcast_freq_re_xq, bcast_freq_im_xq)

bcast_freq_re_xq.deallocate()
bcast_freq_im_xq.deallocate()

t_one_xk = tt_lib.tensor.ones(key_shape, output_mem_config=mem_config)
t_one_xk = tt_lib.tensor.permute(t_one_xk, (3, 1, 2, 0), output_mem_config=mem_config)

bcast_freq_re_xk = tt_lib.tensor.bcast(t_one_xk, freqs_real, BCMUL, BCH, output_mem_config=mem_config)
bcast_freq_im_xk = tt_lib.tensor.bcast(t_one_xk, freqs_imag, BCMUL, BCH, output_mem_config=mem_config)
bcast_freq_re_xk = tt_lib.tensor.permute(bcast_freq_re_xk, (3, 1, 2, 0), output_mem_config=mem_config)
bcast_freq_im_xk = tt_lib.tensor.permute(bcast_freq_im_xk, (3, 1, 2, 0), output_mem_config=mem_config)

bcast_freq_xk = tt_lib.tensor.complex_tensor(bcast_freq_re_xk, bcast_freq_im_xk)

t_one_xk.deallocate()
bcast_freq_re_xk.deallocate()
bcast_freq_im_xk.deallocate()
freqs_cis.deallocate()
freqs_real.deallocate()
freqs_imag.deallocate()

return bcast_freq_xq, bcast_freq_xk
20 changes: 17 additions & 3 deletions models/experimental/mistral/tests/test_mistral_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
from models.experimental.mistral.tt.mistral_attention import TtAttention
from models.experimental.mistral.tt.mistral_configuration import TtModelArgs
from models.experimental.mistral.reference.model import Attention
from models.utility_functions import torch_to_tt_tensor_rm, tt_to_torch_tensor
from models.utility_functions import torch_to_tt_tensor_rm
from models.experimental.mistral.mistral_helper_funcs import unpad_from_zero, get_freqs_cis, format_tensor
from models.utility_functions import (
comp_pcc,
comp_allclose,
Expand Down Expand Up @@ -68,25 +69,38 @@ def test_mistral_attention_inference(
model_args.FALLBACK_EMPTY = empty_ondevice
model_args.FALLBACK_SCATTER = scatter_ondevice
model_args.WEIGHTS_DTYPE = dtype
output_mem_config = tt_lib.tensor.MemoryConfig(
tt_lib.tensor.TensorMemoryLayout.INTERLEAVED, tt_lib.tensor.BufferType.DRAM
)
tt_cache_path = "/mnt/MLPerf/tt_dnn-models/tt/Mistral/"
tt_model = TtAttention(
args=model_args,
device=device,
base_address=base_address,
tt_cache_path=tt_cache_path,
output_mem_config=output_mem_config,
)
input = torch.randn(1, 11, 4096)
seqlen = input.shape[1]
empty_tensor = torch.zeros((11, 64))
freqs_cis = torch.complex(empty_tensor, empty_tensor)
query_shape = [1, 11, model_args.n_heads, model_args.head_dim // 2]
key_shape = [1, 11, model_args.n_kv_heads, model_args.head_dim // 2]
bcast_freq_xq, bcast_freq_xk = get_freqs_cis(freqs_cis, query_shape, key_shape, device, output_mem_config)
positions = torch.arange(0, 11)
mask = torch.randn(11, 11)

reference_output = reference_model(input, freqs_cis, positions, mask=mask)
del reference_model
tt_input = torch_to_tt_tensor_rm(input, device)
tt_position = torch_to_tt_tensor_rm(positions, device, put_on_device=False)
tt_output = tt_model(tt_input, freqs_cis, tt_position, mask)
tt_output_torch = tt_to_torch_tensor(tt_output).squeeze(0)
mask = torch_to_tt_tensor_rm(mask, device, put_on_device=False)
mask = format_tensor(mask, tt_lib.tensor.Layout.TILE, device, output_mem_config, pad_value=-10000)
tt_input = format_tensor(tt_input, tt_lib.tensor.Layout.TILE, device, output_mem_config)
tt_output = tt_model(tt_input, bcast_freq_xq, bcast_freq_xk, tt_position, mask, seqlen)
desired_shape = list(reference_output.shape)
desired_shape.insert(0, 1)
tt_output_torch = unpad_from_zero(tt_output, desired_shape).squeeze(0)

passing, pcc_message = comp_pcc(reference_output, tt_output_torch, pcc)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,17 @@ def test_mistral_feed_forward_inference(pcc, model_location_generator, device, d
model_args.WEIGHTS_DTYPE = dtype
reference_model = FeedForward(args=model_args)
reference_model.load_state_dict(state_dict)

output_mem_config = tt_lib.tensor.MemoryConfig(
tt_lib.tensor.TensorMemoryLayout.INTERLEAVED, tt_lib.tensor.BufferType.DRAM
)
tt_cache_path = "/mnt/MLPerf/tt_dnn-models/tt/Mistral/"

tt_model = TtFeedForward(
args=model_args,
device=device,
base_address=base_address,
tt_cache_path=tt_cache_path,
output_mem_config=output_mem_config,
)
input = torch.rand(1, 11, 4096)
reference_output = reference_model(input)
Expand Down
6 changes: 5 additions & 1 deletion models/experimental/mistral/tests/test_mistral_rms_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,16 @@ def test_mistral_rms_norm_inference(pcc, model_location_generator, device, reset
model_args.max_batch_size = 1
reference_model = RMSNorm(dim=dim)
reference_model.load_state_dict(state_dict)

output_mem_config = tt_lib.tensor.MemoryConfig(
tt_lib.tensor.TensorMemoryLayout.INTERLEAVED, tt_lib.tensor.BufferType.DRAM
)
tt_cache_path = "/mnt/MLPerf/tt_dnn-models/tt/Mistral/"
tt_model = TtRMSNorm(
dim=dim,
base_address=base_address,
tt_cache_path=tt_cache_path,
output_mem_config=output_mem_config,
device=device,
)
input = torch.rand(1, 11, 4096)
reference_output = reference_model(input)
Expand Down
5 changes: 3 additions & 2 deletions models/experimental/mistral/tests/test_mistral_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from models.experimental.mistral.tt.mistral_configuration import TtModelArgs
from models.experimental.mistral.reference.model import Transformer
from models.experimental.mistral.reference.tokenizer import Tokenizer
from models.experimental.mistral.mistral_helper_funcs import unpad_from_zero
from models.utility_functions import tt_to_torch_tensor
from models.utility_functions import (
comp_pcc,
Expand Down Expand Up @@ -42,7 +43,7 @@ def test_mistral_transformer_inference(pcc, model_location_generator, device, dt
model_args.WEIGHTS_DTYPE = dtype
model_args.max_batch_size = 1
model_args.n_layers = 32
model_args.WEIGHTS_DTYPE = dtype

reference_model = Transformer(args=model_args)
reference_model.load_state_dict(state_dict)

Expand All @@ -69,7 +70,7 @@ def test_mistral_transformer_inference(pcc, model_location_generator, device, dt

reference_output = reference_model(input_tokens[:, :min_prompt_len], positions)

tt_output_torch = tt_to_torch_tensor(tt_output).squeeze(0).squeeze(0)
tt_output_torch = tt_to_torch_tensor(tt_output).squeeze(0)

passing, pcc_message = comp_pcc(reference_output, tt_output_torch, pcc)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
from models.experimental.mistral.tt.mistral_transformer_block import TtTransformerBlock
from models.experimental.mistral.tt.mistral_configuration import TtModelArgs
from models.experimental.mistral.reference.model import TransformerBlock
from models.utility_functions import torch_to_tt_tensor_rm, tt_to_torch_tensor
from models.utility_functions import torch_to_tt_tensor_rm
from models.experimental.mistral.mistral_helper_funcs import unpad_from_zero, format_tensor, get_freqs_cis
from models.utility_functions import (
comp_pcc,
comp_allclose,
Expand Down Expand Up @@ -38,28 +39,44 @@ def test_mistral_transformer_block_inference(pcc, model_location_generator, devi
model_args.WEIGHTS_DTYPE = dtype
reference_model = TransformerBlock(args=model_args)
reference_model.load_state_dict(state_dict)

output_mem_config = tt_lib.tensor.MemoryConfig(
tt_lib.tensor.TensorMemoryLayout.INTERLEAVED, tt_lib.tensor.BufferType.DRAM
)
tt_cache_path = "/mnt/MLPerf/tt_dnn-models/tt/Mistral/"

tt_model = TtTransformerBlock(
args=model_args,
device=device,
base_address=base_address,
tt_cache_path=tt_cache_path,
output_mem_config=output_mem_config,
)

input = torch.randn(1, 11, 4096)
seqlen = input.shape[1]
empty_tensor = torch.zeros((11, 64))
freqs_cis = torch.complex(empty_tensor, empty_tensor)
query_shape = [1, 11, model_args.n_heads, model_args.head_dim // 2]
key_shape = [1, 11, model_args.n_kv_heads, model_args.head_dim // 2]
bcast_freq_xq, bcast_freq_xk = get_freqs_cis(freqs_cis, query_shape, key_shape, device, output_mem_config)
positions = torch.arange(0, 11)
mask = torch.randn(11, 11)

output_mem_config = tt_lib.tensor.MemoryConfig(
tt_lib.tensor.TensorMemoryLayout.INTERLEAVED, tt_lib.tensor.BufferType.DRAM
)
reference_output = reference_model(input, freqs_cis, positions, mask=mask)

tt_input = torch_to_tt_tensor_rm(input, device)
tt_input = format_tensor(tt_input, tt_lib.tensor.Layout.TILE, device, output_mem_config)
mask = torch_to_tt_tensor_rm(mask, device, put_on_device=False)
mask = format_tensor(mask, tt_lib.tensor.Layout.TILE, device, output_mem_config, pad_value=-10000)

tt_position = torch_to_tt_tensor_rm(positions, device, put_on_device=False)
tt_output = tt_model(tt_input, freqs_cis, tt_position, mask)
tt_output_torch = tt_to_torch_tensor(tt_output).squeeze(0)
tt_output = tt_model(tt_input, bcast_freq_xq, bcast_freq_xk, tt_position, mask, seqlen)

desired_shape = list(reference_output.shape)
desired_shape.insert(0, 1)
tt_output_torch = unpad_from_zero(tt_output, desired_shape).squeeze(0)

passing, pcc_message = comp_pcc(reference_output, tt_output_torch, pcc)

Expand Down
Loading

0 comments on commit 92b3236

Please sign in to comment.