#3812: Use tilize operators for mistral model #4029

Merged: 1 commit, Jan 2, 2024
104 changes: 103 additions & 1 deletion models/experimental/mistral/mistral_helper_funcs.py
@@ -3,7 +3,8 @@
# SPDX-License-Identifier: Apache-2.0
import tt_lib
from typing import Optional
from models.utility_functions import tt_to_torch_tensor, torch_to_tt_tensor_rm
from models.utility_functions import tt_to_torch_tensor, torch_to_tt_tensor_rm, tt2torch_tensor
import torch


def Linear(
@@ -45,3 +46,104 @@ def linear_(activation):
return output

return linear_


def format_tensor(x, target_layout, device, output_mem_config, pad_value=0.0):
if x.layout() == target_layout:
return x
if x.layout() == tt_lib.tensor.Layout.ROW_MAJOR and target_layout == tt_lib.tensor.Layout.TILE:
x_padded_shape = tt_lib.tensor.pad_to_tile_shape(x.shape(), False, False, True, True)
if x.shape() != x_padded_shape:
return tt_lib.tensor.format_input_tensor(
x, device, x_padded_shape, pad_value, target_layout, output_mem_config
)
else:
return tt_lib.tensor.tilize(x, output_mem_config, use_multicore=True)
elif x.layout() == tt_lib.tensor.Layout.TILE and target_layout == tt_lib.tensor.Layout.ROW_MAJOR:
if x.shape() != x.shape_without_padding():
return tt_lib.tensor.format_output_tensor(
x, x.shape_without_padding(), device, target_layout, output_mem_config
)
else:
return tt_lib.tensor.untilize(x, output_mem_config, use_multicore=True)
else:
assert False


def unpad_from_zero(x, desired_shape):
if x.shape()[-1] == desired_shape[-1] and x.shape()[-2] == desired_shape[-2]:
x = tt2torch_tensor(x)
else:
x = x.cpu()
if x.layout() != tt_lib.tensor.Layout.ROW_MAJOR:
x = x.to(tt_lib.tensor.Layout.ROW_MAJOR)
x = x.unpad(
(0, 0, 0, 0), (desired_shape[0] - 1, desired_shape[1] - 1, desired_shape[2] - 1, desired_shape[3] - 1)
)
x = x.to_torch().to(torch.float)
return x


def _reshape_for_broadcast(freqs_cis: torch.Tensor, x_shape, x_ndim) -> torch.Tensor:
"""
freqs_cis: complex - (seq_len, head_dim / 2)
x: complex - (bsz, seq_len, head_dim / 2)
"""
ndim = x_ndim
assert 1 < ndim
assert freqs_cis.shape == (x_shape[1], x_shape[-1]), (
freqs_cis.shape,
(x_shape[1], x_shape[-1]),
)
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x_shape)]
return freqs_cis.view(*shape)


def get_freqs_cis(freqs_cis: torch.Tensor, query_shape, key_shape, device=None, mem_config=None):
freqs_cis = _reshape_for_broadcast(freqs_cis, query_shape, 4)

freq_real = torch_to_tt_tensor_rm(freqs_cis.real, device)
freq_img = torch_to_tt_tensor_rm(freqs_cis.imag, device)
freqs_cis = tt_lib.tensor.complex_tensor(freq_real, freq_img)

freq_real.deallocate()
freq_img.deallocate()

BCH = tt_lib.tensor.BcastOpDim.HW
BCMUL = tt_lib.tensor.BcastOpMath.MUL

t_one_xq = tt_lib.tensor.ones(query_shape, output_mem_config=mem_config)
t_one_xq = tt_lib.tensor.permute(t_one_xq, (3, 1, 2, 0), output_mem_config=mem_config)

freqs_real = tt_lib.tensor.permute(freqs_cis.real, (3, 1, 2, 0), output_mem_config=mem_config)
freqs_imag = tt_lib.tensor.permute(freqs_cis.imag, (3, 1, 2, 0), output_mem_config=mem_config)

bcast_freq_re_xq = tt_lib.tensor.bcast(t_one_xq, freqs_real, BCMUL, BCH, output_mem_config=mem_config)
bcast_freq_im_xq = tt_lib.tensor.bcast(t_one_xq, freqs_imag, BCMUL, BCH, output_mem_config=mem_config)
bcast_freq_re_xq = tt_lib.tensor.permute(bcast_freq_re_xq, (3, 1, 2, 0), output_mem_config=mem_config)
bcast_freq_im_xq = tt_lib.tensor.permute(bcast_freq_im_xq, (3, 1, 2, 0), output_mem_config=mem_config)
t_one_xq.deallocate()

bcast_freq_xq = tt_lib.tensor.complex_tensor(bcast_freq_re_xq, bcast_freq_im_xq)

bcast_freq_re_xq.deallocate()
bcast_freq_im_xq.deallocate()

t_one_xk = tt_lib.tensor.ones(key_shape, output_mem_config=mem_config)
t_one_xk = tt_lib.tensor.permute(t_one_xk, (3, 1, 2, 0), output_mem_config=mem_config)

bcast_freq_re_xk = tt_lib.tensor.bcast(t_one_xk, freqs_real, BCMUL, BCH, output_mem_config=mem_config)
bcast_freq_im_xk = tt_lib.tensor.bcast(t_one_xk, freqs_imag, BCMUL, BCH, output_mem_config=mem_config)
bcast_freq_re_xk = tt_lib.tensor.permute(bcast_freq_re_xk, (3, 1, 2, 0), output_mem_config=mem_config)
bcast_freq_im_xk = tt_lib.tensor.permute(bcast_freq_im_xk, (3, 1, 2, 0), output_mem_config=mem_config)

bcast_freq_xk = tt_lib.tensor.complex_tensor(bcast_freq_re_xk, bcast_freq_im_xk)

t_one_xk.deallocate()
bcast_freq_re_xk.deallocate()
bcast_freq_im_xk.deallocate()
freqs_cis.deallocate()
freqs_real.deallocate()
freqs_imag.deallocate()

return bcast_freq_xq, bcast_freq_xk
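
The updated tests below use these helpers together: inputs are converted to TILE layout (with padding) before running the TT model, the rotary frequencies are pre-broadcast once per sequence length with `get_freqs_cis`, and the padded output is cropped back to its logical shape with `unpad_from_zero`. A minimal usage sketch modelled on those tests follows; the `device` handle, head counts, and sequence length are illustrative placeholders, not part of this PR.

```python
import torch
import tt_lib
from models.utility_functions import torch_to_tt_tensor_rm
from models.experimental.mistral.mistral_helper_funcs import (
    format_tensor,
    get_freqs_cis,
    unpad_from_zero,
)

# Assumption: `device` has already been opened by the caller (the tests get it from a fixture).
mem_config = tt_lib.tensor.MemoryConfig(
    tt_lib.tensor.TensorMemoryLayout.INTERLEAVED, tt_lib.tensor.BufferType.DRAM
)

# Host tensor -> row-major TT tensor -> padded TILE-layout tensor on device.
x = torch.randn(1, 11, 4096)
tt_x = torch_to_tt_tensor_rm(x, device)
tt_x = format_tensor(tt_x, tt_lib.tensor.Layout.TILE, device, mem_config)

# Pre-broadcast the rotary frequencies for seq_len = 11 and head_dim = 128
# (32 query heads / 8 KV heads, as in Mistral-7B).
freqs_cis = torch.complex(torch.zeros(11, 64), torch.zeros(11, 64))
query_shape = [1, 11, 32, 64]  # [batch, seq_len, n_heads, head_dim // 2]
key_shape = [1, 11, 8, 64]     # [batch, seq_len, n_kv_heads, head_dim // 2]
bcast_freq_xq, bcast_freq_xk = get_freqs_cis(freqs_cis, query_shape, key_shape, device, mem_config)

# ... run the TT model on tt_x, bcast_freq_xq, bcast_freq_xk ...

# Crop the tile padding from the model output back to its logical shape:
# out = unpad_from_zero(tt_output, [1, 1, 11, 4096])
```

Note that `format_tensor` only routes through `format_input_tensor`/`format_output_tensor` when the shape is not already tile-aligned; otherwise it calls the multicore `tilize`/`untilize` ops directly.
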
20 changes: 17 additions & 3 deletions models/experimental/mistral/tests/test_mistral_attention.py
@@ -10,7 +10,8 @@
from models.experimental.mistral.tt.mistral_attention import TtAttention
from models.experimental.mistral.tt.mistral_configuration import TtModelArgs
from models.experimental.mistral.reference.model import Attention
from models.utility_functions import torch_to_tt_tensor_rm, tt_to_torch_tensor
from models.utility_functions import torch_to_tt_tensor_rm
from models.experimental.mistral.mistral_helper_funcs import unpad_from_zero, get_freqs_cis, format_tensor
from models.utility_functions import (
comp_pcc,
comp_allclose,
@@ -68,25 +69,38 @@ def test_mistral_attention_inference(
model_args.FALLBACK_EMPTY = empty_ondevice
model_args.FALLBACK_SCATTER = scatter_ondevice
model_args.WEIGHTS_DTYPE = dtype
output_mem_config = tt_lib.tensor.MemoryConfig(
tt_lib.tensor.TensorMemoryLayout.INTERLEAVED, tt_lib.tensor.BufferType.DRAM
)
tt_cache_path = "/mnt/MLPerf/tt_dnn-models/tt/Mistral/"
tt_model = TtAttention(
args=model_args,
device=device,
base_address=base_address,
tt_cache_path=tt_cache_path,
output_mem_config=output_mem_config,
)
input = torch.randn(1, 11, 4096)
seqlen = input.shape[1]
empty_tensor = torch.zeros((11, 64))
freqs_cis = torch.complex(empty_tensor, empty_tensor)
query_shape = [1, 11, model_args.n_heads, model_args.head_dim // 2]
key_shape = [1, 11, model_args.n_kv_heads, model_args.head_dim // 2]
bcast_freq_xq, bcast_freq_xk = get_freqs_cis(freqs_cis, query_shape, key_shape, device, output_mem_config)
positions = torch.arange(0, 11)
mask = torch.randn(11, 11)

reference_output = reference_model(input, freqs_cis, positions, mask=mask)
del reference_model
tt_input = torch_to_tt_tensor_rm(input, device)
tt_position = torch_to_tt_tensor_rm(positions, device, put_on_device=False)
tt_output = tt_model(tt_input, freqs_cis, tt_position, mask)
tt_output_torch = tt_to_torch_tensor(tt_output).squeeze(0)
mask = torch_to_tt_tensor_rm(mask, device, put_on_device=False)
mask = format_tensor(mask, tt_lib.tensor.Layout.TILE, device, output_mem_config, pad_value=-10000)
tt_input = format_tensor(tt_input, tt_lib.tensor.Layout.TILE, device, output_mem_config)
tt_output = tt_model(tt_input, bcast_freq_xq, bcast_freq_xk, tt_position, mask, seqlen)
desired_shape = list(reference_output.shape)
desired_shape.insert(0, 1)
tt_output_torch = unpad_from_zero(tt_output, desired_shape).squeeze(0)

passing, pcc_message = comp_pcc(reference_output, tt_output_torch, pcc)

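The new `desired_shape` bookkeeping in this test exists because TILE layout pads the last two dimensions up to tt_lib's 32x32 tile size, so the attention output comes back with extra rows that `unpad_from_zero` must strip. A small illustrative sketch of that padding arithmetic (the helper below is hypothetical, not part of tt_lib):

```python
import math

TILE = 32  # tt_lib tiles are 32x32

def padded_tile_shape(shape, tile=TILE):
    """Round the last two dimensions up to the nearest tile multiple."""
    *batch, h, w = shape
    return (*batch, math.ceil(h / tile) * tile, math.ceil(w / tile) * tile)

# An 11 x 4096 attention output occupies 32 x 4096 once tilized;
# unpad_from_zero crops it back to the reference shape [1, 1, 11, 4096].
print(padded_tile_shape((1, 1, 11, 4096)))  # -> (1, 1, 32, 4096)
```

The same padding is why the mask is tilized with `pad_value=-10000` rather than the default 0.0: the padded rows and columns of the 11x11 mask must stay strongly negative so the softmax continues to ignore them.
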
@@ -36,14 +36,17 @@ def test_mistral_feed_forward_inference(pcc, model_location_generator, device, d
model_args.WEIGHTS_DTYPE = dtype
reference_model = FeedForward(args=model_args)
reference_model.load_state_dict(state_dict)

output_mem_config = tt_lib.tensor.MemoryConfig(
tt_lib.tensor.TensorMemoryLayout.INTERLEAVED, tt_lib.tensor.BufferType.DRAM
)
tt_cache_path = "/mnt/MLPerf/tt_dnn-models/tt/Mistral/"

tt_model = TtFeedForward(
args=model_args,
device=device,
base_address=base_address,
tt_cache_path=tt_cache_path,
output_mem_config=output_mem_config,
)
input = torch.rand(1, 11, 4096)
reference_output = reference_model(input)
6 changes: 5 additions & 1 deletion models/experimental/mistral/tests/test_mistral_rms_norm.py
@@ -33,12 +33,16 @@ def test_mistral_rms_norm_inference(pcc, model_location_generator, device, reset
model_args.max_batch_size = 1
reference_model = RMSNorm(dim=dim)
reference_model.load_state_dict(state_dict)

output_mem_config = tt_lib.tensor.MemoryConfig(
tt_lib.tensor.TensorMemoryLayout.INTERLEAVED, tt_lib.tensor.BufferType.DRAM
)
tt_cache_path = "/mnt/MLPerf/tt_dnn-models/tt/Mistral/"
tt_model = TtRMSNorm(
dim=dim,
base_address=base_address,
tt_cache_path=tt_cache_path,
output_mem_config=output_mem_config,
device=device,
)
input = torch.rand(1, 11, 4096)
reference_output = reference_model(input)
5 changes: 3 additions & 2 deletions models/experimental/mistral/tests/test_mistral_transformer.py
@@ -12,6 +12,7 @@
from models.experimental.mistral.tt.mistral_configuration import TtModelArgs
from models.experimental.mistral.reference.model import Transformer
from models.experimental.mistral.reference.tokenizer import Tokenizer
from models.experimental.mistral.mistral_helper_funcs import unpad_from_zero
from models.utility_functions import tt_to_torch_tensor
from models.utility_functions import (
comp_pcc,
@@ -42,7 +43,7 @@ def test_mistral_transformer_inference(pcc, model_location_generator, device, dt
model_args.WEIGHTS_DTYPE = dtype
model_args.max_batch_size = 1
model_args.n_layers = 32
model_args.WEIGHTS_DTYPE = dtype

reference_model = Transformer(args=model_args)
reference_model.load_state_dict(state_dict)

@@ -69,7 +70,7 @@ def test_mistral_transformer_inference(pcc, model_location_generator, device, dt

reference_output = reference_model(input_tokens[:, :min_prompt_len], positions)

tt_output_torch = tt_to_torch_tensor(tt_output).squeeze(0).squeeze(0)
tt_output_torch = tt_to_torch_tensor(tt_output).squeeze(0)

passing, pcc_message = comp_pcc(reference_output, tt_output_torch, pcc)

@@ -10,7 +10,8 @@
from models.experimental.mistral.tt.mistral_transformer_block import TtTransformerBlock
from models.experimental.mistral.tt.mistral_configuration import TtModelArgs
from models.experimental.mistral.reference.model import TransformerBlock
from models.utility_functions import torch_to_tt_tensor_rm, tt_to_torch_tensor
from models.utility_functions import torch_to_tt_tensor_rm
from models.experimental.mistral.mistral_helper_funcs import unpad_from_zero, format_tensor, get_freqs_cis
from models.utility_functions import (
comp_pcc,
comp_allclose,
@@ -38,28 +39,44 @@ def test_mistral_transformer_block_inference(pcc, model_location_generator, devi
model_args.WEIGHTS_DTYPE = dtype
reference_model = TransformerBlock(args=model_args)
reference_model.load_state_dict(state_dict)

output_mem_config = tt_lib.tensor.MemoryConfig(
tt_lib.tensor.TensorMemoryLayout.INTERLEAVED, tt_lib.tensor.BufferType.DRAM
)
tt_cache_path = "/mnt/MLPerf/tt_dnn-models/tt/Mistral/"

tt_model = TtTransformerBlock(
args=model_args,
device=device,
base_address=base_address,
tt_cache_path=tt_cache_path,
output_mem_config=output_mem_config,
)

input = torch.randn(1, 11, 4096)
seqlen = input.shape[1]
empty_tensor = torch.zeros((11, 64))
freqs_cis = torch.complex(empty_tensor, empty_tensor)
query_shape = [1, 11, model_args.n_heads, model_args.head_dim // 2]
key_shape = [1, 11, model_args.n_kv_heads, model_args.head_dim // 2]
bcast_freq_xq, bcast_freq_xk = get_freqs_cis(freqs_cis, query_shape, key_shape, device, output_mem_config)
positions = torch.arange(0, 11)
mask = torch.randn(11, 11)

output_mem_config = tt_lib.tensor.MemoryConfig(
tt_lib.tensor.TensorMemoryLayout.INTERLEAVED, tt_lib.tensor.BufferType.DRAM
)
reference_output = reference_model(input, freqs_cis, positions, mask=mask)

tt_input = torch_to_tt_tensor_rm(input, device)
tt_input = format_tensor(tt_input, tt_lib.tensor.Layout.TILE, device, output_mem_config)
mask = torch_to_tt_tensor_rm(mask, device, put_on_device=False)
mask = format_tensor(mask, tt_lib.tensor.Layout.TILE, device, output_mem_config, pad_value=-10000)

tt_position = torch_to_tt_tensor_rm(positions, device, put_on_device=False)
tt_output = tt_model(tt_input, freqs_cis, tt_position, mask)
tt_output_torch = tt_to_torch_tensor(tt_output).squeeze(0)
tt_output = tt_model(tt_input, bcast_freq_xq, bcast_freq_xk, tt_position, mask, seqlen)

desired_shape = list(reference_output.shape)
desired_shape.insert(0, 1)
tt_output_torch = unpad_from_zero(tt_output, desired_shape).squeeze(0)

passing, pcc_message = comp_pcc(reference_output, tt_output_torch, pcc)
