From 59d02dcd8f188a0a320ddb18fcb898853f313ffd Mon Sep 17 00:00:00 2001 From: vigneshkeerthivasanx Date: Thu, 23 Nov 2023 13:00:51 +0000 Subject: [PATCH] #3812: Use tilize operators for mistral model --- .../mistral/mistral_helper_funcs.py | 104 ++++++++- .../mistral/tests/test_mistral_attention.py | 20 +- .../tests/test_mistral_feed_forward.py | 5 +- .../mistral/tests/test_mistral_rms_norm.py | 6 +- .../mistral/tests/test_mistral_transformer.py | 5 +- .../tests/test_mistral_transformer_block.py | 27 ++- .../mistral/tt/mistral_attention.py | 203 ++++++------------ .../mistral/tt/mistral_configuration.py | 1 - .../mistral/tt/mistral_feed_forward.py | 7 +- .../mistral/tt/mistral_rms_norm.py | 11 +- .../mistral/tt/mistral_transformer.py | 37 +++- .../mistral/tt/mistral_transformer_block.py | 30 ++- 12 files changed, 289 insertions(+), 167 deletions(-) diff --git a/models/experimental/mistral/mistral_helper_funcs.py b/models/experimental/mistral/mistral_helper_funcs.py index d9860f8305d..3cd9647b0d5 100644 --- a/models/experimental/mistral/mistral_helper_funcs.py +++ b/models/experimental/mistral/mistral_helper_funcs.py @@ -3,7 +3,8 @@ # SPDX-License-Identifier: Apache-2.0 import tt_lib from typing import Optional -from models.utility_functions import tt_to_torch_tensor, torch_to_tt_tensor_rm +from models.utility_functions import tt_to_torch_tensor, torch_to_tt_tensor_rm, tt2torch_tensor +import torch def Linear( @@ -45,3 +46,104 @@ def linear_(activation): return output return linear_ + + +def format_tensor(x, target_layout, device, output_mem_config, pad_value=0.0): + if x.layout() == target_layout: + return x + if x.layout() == tt_lib.tensor.Layout.ROW_MAJOR and target_layout == tt_lib.tensor.Layout.TILE: + x_padded_shape = tt_lib.tensor.pad_to_tile_shape(x.shape(), False, False, True, True) + if x.shape() != x_padded_shape: + return tt_lib.tensor.format_input_tensor( + x, device, x_padded_shape, pad_value, target_layout, output_mem_config + ) + else: + return tt_lib.tensor.tilize(x, output_mem_config, use_multicore=True) + elif x.layout() == tt_lib.tensor.Layout.TILE and target_layout == tt_lib.tensor.Layout.ROW_MAJOR: + if x.shape() != x.shape_without_padding(): + return tt_lib.tensor.format_output_tensor( + x, x.shape_without_padding(), device, target_layout, output_mem_config + ) + else: + return tt_lib.tensor.untilize(x, output_mem_config, use_multicore=True) + else: + assert False + + +def unpad_from_zero(x, desired_shape): + if x.shape()[-1] == desired_shape[-1] and x.shape()[-2] == desired_shape[-2]: + x = tt2torch_tensor(x) + else: + x = x.cpu() + if x.layout() != tt_lib.tensor.Layout.ROW_MAJOR: + x = x.to(tt_lib.tensor.Layout.ROW_MAJOR) + x = x.unpad( + (0, 0, 0, 0), (desired_shape[0] - 1, desired_shape[1] - 1, desired_shape[2] - 1, desired_shape[3] - 1) + ) + x = x.to_torch().to(torch.float) + return x + + +def _reshape_for_broadcast(freqs_cis: torch.Tensor, x_shape, x_ndim) -> torch.Tensor: + """ + freqs_cis: complex - (seq_len, head_dim / 2) + x: complex - (bsz, seq_len, head_dim / 2) + """ + ndim = x_ndim + assert 1 < ndim + assert freqs_cis.shape == (x_shape[1], x_shape[-1]), ( + freqs_cis.shape, + (x_shape[1], x_shape[-1]), + ) + shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x_shape)] + return freqs_cis.view(*shape) + + +def get_freqs_cis(freqs_cis: torch.Tensor, query_shape, key_shape, device=None, mem_config=None): + freqs_cis = _reshape_for_broadcast(freqs_cis, query_shape, 4) + + freq_real = torch_to_tt_tensor_rm(freqs_cis.real, device) + freq_img = torch_to_tt_tensor_rm(freqs_cis.imag, device) + freqs_cis = tt_lib.tensor.complex_tensor(freq_real, freq_img) + + freq_real.deallocate() + freq_img.deallocate() + + BCH = tt_lib.tensor.BcastOpDim.HW + BCMUL = tt_lib.tensor.BcastOpMath.MUL + + t_one_xq = tt_lib.tensor.ones(query_shape, output_mem_config=mem_config) + t_one_xq = tt_lib.tensor.permute(t_one_xq, (3, 1, 2, 0), output_mem_config=mem_config) + + freqs_real = tt_lib.tensor.permute(freqs_cis.real, (3, 1, 2, 0), output_mem_config=mem_config) + freqs_imag = tt_lib.tensor.permute(freqs_cis.imag, (3, 1, 2, 0), output_mem_config=mem_config) + + bcast_freq_re_xq = tt_lib.tensor.bcast(t_one_xq, freqs_real, BCMUL, BCH, output_mem_config=mem_config) + bcast_freq_im_xq = tt_lib.tensor.bcast(t_one_xq, freqs_imag, BCMUL, BCH, output_mem_config=mem_config) + bcast_freq_re_xq = tt_lib.tensor.permute(bcast_freq_re_xq, (3, 1, 2, 0), output_mem_config=mem_config) + bcast_freq_im_xq = tt_lib.tensor.permute(bcast_freq_im_xq, (3, 1, 2, 0), output_mem_config=mem_config) + t_one_xq.deallocate() + + bcast_freq_xq = tt_lib.tensor.complex_tensor(bcast_freq_re_xq, bcast_freq_im_xq) + + bcast_freq_re_xq.deallocate() + bcast_freq_im_xq.deallocate() + + t_one_xk = tt_lib.tensor.ones(key_shape, output_mem_config=mem_config) + t_one_xk = tt_lib.tensor.permute(t_one_xk, (3, 1, 2, 0), output_mem_config=mem_config) + + bcast_freq_re_xk = tt_lib.tensor.bcast(t_one_xk, freqs_real, BCMUL, BCH, output_mem_config=mem_config) + bcast_freq_im_xk = tt_lib.tensor.bcast(t_one_xk, freqs_imag, BCMUL, BCH, output_mem_config=mem_config) + bcast_freq_re_xk = tt_lib.tensor.permute(bcast_freq_re_xk, (3, 1, 2, 0), output_mem_config=mem_config) + bcast_freq_im_xk = tt_lib.tensor.permute(bcast_freq_im_xk, (3, 1, 2, 0), output_mem_config=mem_config) + + bcast_freq_xk = tt_lib.tensor.complex_tensor(bcast_freq_re_xk, bcast_freq_im_xk) + + t_one_xk.deallocate() + bcast_freq_re_xk.deallocate() + bcast_freq_im_xk.deallocate() + freqs_cis.deallocate() + freqs_real.deallocate() + freqs_imag.deallocate() + + return bcast_freq_xq, bcast_freq_xk diff --git a/models/experimental/mistral/tests/test_mistral_attention.py b/models/experimental/mistral/tests/test_mistral_attention.py index 3e8ff6f695a..5761a126dde 100644 --- a/models/experimental/mistral/tests/test_mistral_attention.py +++ b/models/experimental/mistral/tests/test_mistral_attention.py @@ -10,7 +10,8 @@ from models.experimental.mistral.tt.mistral_attention import TtAttention from models.experimental.mistral.tt.mistral_configuration import TtModelArgs from models.experimental.mistral.reference.model import Attention -from models.utility_functions import torch_to_tt_tensor_rm, tt_to_torch_tensor +from models.utility_functions import torch_to_tt_tensor_rm +from models.experimental.mistral.mistral_helper_funcs import unpad_from_zero, get_freqs_cis, format_tensor from models.utility_functions import ( comp_pcc, comp_allclose, @@ -68,16 +69,24 @@ def test_mistral_attention_inference( model_args.FALLBACK_EMPTY = empty_ondevice model_args.FALLBACK_SCATTER = scatter_ondevice model_args.WEIGHTS_DTYPE = dtype + output_mem_config = tt_lib.tensor.MemoryConfig( + tt_lib.tensor.TensorMemoryLayout.INTERLEAVED, tt_lib.tensor.BufferType.DRAM + ) tt_cache_path = "/mnt/MLPerf/tt_dnn-models/tt/Mistral/" tt_model = TtAttention( args=model_args, device=device, base_address=base_address, tt_cache_path=tt_cache_path, + output_mem_config=output_mem_config, ) input = torch.randn(1, 11, 4096) + seqlen = input.shape[1] empty_tensor = torch.zeros((11, 64)) freqs_cis = torch.complex(empty_tensor, empty_tensor) + query_shape = [1, 11, model_args.n_heads, model_args.head_dim // 2] + key_shape = [1, 11, model_args.n_kv_heads, model_args.head_dim // 2] + bcast_freq_xq, bcast_freq_xk = get_freqs_cis(freqs_cis, query_shape, key_shape, device, output_mem_config) positions = torch.arange(0, 11) mask = torch.randn(11, 11) @@ -85,8 +94,13 @@ def test_mistral_attention_inference( del reference_model tt_input = torch_to_tt_tensor_rm(input, device) tt_position = torch_to_tt_tensor_rm(positions, device, put_on_device=False) - tt_output = tt_model(tt_input, freqs_cis, tt_position, mask) - tt_output_torch = tt_to_torch_tensor(tt_output).squeeze(0) + mask = torch_to_tt_tensor_rm(mask, device, put_on_device=False) + mask = format_tensor(mask, tt_lib.tensor.Layout.TILE, device, output_mem_config, pad_value=-10000) + tt_input = format_tensor(tt_input, tt_lib.tensor.Layout.TILE, device, output_mem_config) + tt_output = tt_model(tt_input, bcast_freq_xq, bcast_freq_xk, tt_position, mask, seqlen) + desired_shape = list(reference_output.shape) + desired_shape.insert(0, 1) + tt_output_torch = unpad_from_zero(tt_output, desired_shape).squeeze(0) passing, pcc_message = comp_pcc(reference_output, tt_output_torch, pcc) diff --git a/models/experimental/mistral/tests/test_mistral_feed_forward.py b/models/experimental/mistral/tests/test_mistral_feed_forward.py index 2fcec5dbc6d..71f2fe8ba81 100644 --- a/models/experimental/mistral/tests/test_mistral_feed_forward.py +++ b/models/experimental/mistral/tests/test_mistral_feed_forward.py @@ -36,7 +36,9 @@ def test_mistral_feed_forward_inference(pcc, model_location_generator, device, d model_args.WEIGHTS_DTYPE = dtype reference_model = FeedForward(args=model_args) reference_model.load_state_dict(state_dict) - + output_mem_config = tt_lib.tensor.MemoryConfig( + tt_lib.tensor.TensorMemoryLayout.INTERLEAVED, tt_lib.tensor.BufferType.DRAM + ) tt_cache_path = "/mnt/MLPerf/tt_dnn-models/tt/Mistral/" tt_model = TtFeedForward( @@ -44,6 +46,7 @@ def test_mistral_feed_forward_inference(pcc, model_location_generator, device, d device=device, base_address=base_address, tt_cache_path=tt_cache_path, + output_mem_config=output_mem_config, ) input = torch.rand(1, 11, 4096) reference_output = reference_model(input) diff --git a/models/experimental/mistral/tests/test_mistral_rms_norm.py b/models/experimental/mistral/tests/test_mistral_rms_norm.py index a87025bbad9..ae13243d9c2 100644 --- a/models/experimental/mistral/tests/test_mistral_rms_norm.py +++ b/models/experimental/mistral/tests/test_mistral_rms_norm.py @@ -33,12 +33,16 @@ def test_mistral_rms_norm_inference(pcc, model_location_generator, device, reset model_args.max_batch_size = 1 reference_model = RMSNorm(dim=dim) reference_model.load_state_dict(state_dict) - + output_mem_config = tt_lib.tensor.MemoryConfig( + tt_lib.tensor.TensorMemoryLayout.INTERLEAVED, tt_lib.tensor.BufferType.DRAM + ) tt_cache_path = "/mnt/MLPerf/tt_dnn-models/tt/Mistral/" tt_model = TtRMSNorm( dim=dim, base_address=base_address, tt_cache_path=tt_cache_path, + output_mem_config=output_mem_config, + device=device, ) input = torch.rand(1, 11, 4096) reference_output = reference_model(input) diff --git a/models/experimental/mistral/tests/test_mistral_transformer.py b/models/experimental/mistral/tests/test_mistral_transformer.py index 3922634c4d5..a2f39bad03b 100644 --- a/models/experimental/mistral/tests/test_mistral_transformer.py +++ b/models/experimental/mistral/tests/test_mistral_transformer.py @@ -12,6 +12,7 @@ from models.experimental.mistral.tt.mistral_configuration import TtModelArgs from models.experimental.mistral.reference.model import Transformer from models.experimental.mistral.reference.tokenizer import Tokenizer +from models.experimental.mistral.mistral_helper_funcs import unpad_from_zero from models.utility_functions import tt_to_torch_tensor from models.utility_functions import ( comp_pcc, @@ -42,7 +43,7 @@ def test_mistral_transformer_inference(pcc, model_location_generator, device, dt model_args.WEIGHTS_DTYPE = dtype model_args.max_batch_size = 1 model_args.n_layers = 32 - model_args.WEIGHTS_DTYPE = dtype + reference_model = Transformer(args=model_args) reference_model.load_state_dict(state_dict) @@ -69,7 +70,7 @@ def test_mistral_transformer_inference(pcc, model_location_generator, device, dt reference_output = reference_model(input_tokens[:, :min_prompt_len], positions) - tt_output_torch = tt_to_torch_tensor(tt_output).squeeze(0).squeeze(0) + tt_output_torch = tt_to_torch_tensor(tt_output).squeeze(0) passing, pcc_message = comp_pcc(reference_output, tt_output_torch, pcc) diff --git a/models/experimental/mistral/tests/test_mistral_transformer_block.py b/models/experimental/mistral/tests/test_mistral_transformer_block.py index dd5a714b38c..64bdf4e2ca7 100644 --- a/models/experimental/mistral/tests/test_mistral_transformer_block.py +++ b/models/experimental/mistral/tests/test_mistral_transformer_block.py @@ -10,7 +10,8 @@ from models.experimental.mistral.tt.mistral_transformer_block import TtTransformerBlock from models.experimental.mistral.tt.mistral_configuration import TtModelArgs from models.experimental.mistral.reference.model import TransformerBlock -from models.utility_functions import torch_to_tt_tensor_rm, tt_to_torch_tensor +from models.utility_functions import torch_to_tt_tensor_rm +from models.experimental.mistral.mistral_helper_funcs import unpad_from_zero, format_tensor, get_freqs_cis from models.utility_functions import ( comp_pcc, comp_allclose, @@ -38,7 +39,9 @@ def test_mistral_transformer_block_inference(pcc, model_location_generator, devi model_args.WEIGHTS_DTYPE = dtype reference_model = TransformerBlock(args=model_args) reference_model.load_state_dict(state_dict) - + output_mem_config = tt_lib.tensor.MemoryConfig( + tt_lib.tensor.TensorMemoryLayout.INTERLEAVED, tt_lib.tensor.BufferType.DRAM + ) tt_cache_path = "/mnt/MLPerf/tt_dnn-models/tt/Mistral/" tt_model = TtTransformerBlock( @@ -46,20 +49,34 @@ def test_mistral_transformer_block_inference(pcc, model_location_generator, devi device=device, base_address=base_address, tt_cache_path=tt_cache_path, + output_mem_config=output_mem_config, ) input = torch.randn(1, 11, 4096) + seqlen = input.shape[1] empty_tensor = torch.zeros((11, 64)) freqs_cis = torch.complex(empty_tensor, empty_tensor) + query_shape = [1, 11, model_args.n_heads, model_args.head_dim // 2] + key_shape = [1, 11, model_args.n_kv_heads, model_args.head_dim // 2] + bcast_freq_xq, bcast_freq_xk = get_freqs_cis(freqs_cis, query_shape, key_shape, device, output_mem_config) positions = torch.arange(0, 11) mask = torch.randn(11, 11) - + output_mem_config = tt_lib.tensor.MemoryConfig( + tt_lib.tensor.TensorMemoryLayout.INTERLEAVED, tt_lib.tensor.BufferType.DRAM + ) reference_output = reference_model(input, freqs_cis, positions, mask=mask) tt_input = torch_to_tt_tensor_rm(input, device) + tt_input = format_tensor(tt_input, tt_lib.tensor.Layout.TILE, device, output_mem_config) + mask = torch_to_tt_tensor_rm(mask, device, put_on_device=False) + mask = format_tensor(mask, tt_lib.tensor.Layout.TILE, device, output_mem_config, pad_value=-10000) + tt_position = torch_to_tt_tensor_rm(positions, device, put_on_device=False) - tt_output = tt_model(tt_input, freqs_cis, tt_position, mask) - tt_output_torch = tt_to_torch_tensor(tt_output).squeeze(0) + tt_output = tt_model(tt_input, bcast_freq_xq, bcast_freq_xk, tt_position, mask, seqlen) + + desired_shape = list(reference_output.shape) + desired_shape.insert(0, 1) + tt_output_torch = unpad_from_zero(tt_output, desired_shape).squeeze(0) passing, pcc_message = comp_pcc(reference_output, tt_output_torch, pcc) diff --git a/models/experimental/mistral/tt/mistral_attention.py b/models/experimental/mistral/tt/mistral_attention.py index 2c0c3ae15bd..a3c00e3063b 100644 --- a/models/experimental/mistral/tt/mistral_attention.py +++ b/models/experimental/mistral/tt/mistral_attention.py @@ -9,8 +9,8 @@ from tt_lib import fallback_ops from tt_lib.fused_ops.softmax import softmax as Ttsoftmax from models.experimental.mistral.tt.mistral_configuration import TtModelArgs -from models.utility_functions import torch_to_tt_tensor_rm, tt_to_torch_tensor -from models.experimental.mistral.mistral_helper_funcs import Linear as TtLinear +from models.utility_functions import torch_to_tt_tensor_rm, tt_to_torch_tensor, torch_to_tt_tensor +from models.experimental.mistral.mistral_helper_funcs import Linear as TtLinear, format_tensor, unpad_from_zero class TtAttention(nn.Module): @@ -20,12 +20,13 @@ def __init__( base_address=None, device=None, tt_cache_path=None, + output_mem_config=None, ): super().__init__() self.args = args self.device = device self.base_address = base_address - + self.output_mem_config = output_mem_config self.n_heads: int = args.n_heads self.n_kv_heads: int = args.n_kv_heads @@ -113,29 +114,39 @@ def repeat_kv(self, keys: torch.Tensor, values: torch.Tensor, repeats: int) -> t def forward( self, x: tt_lib.tensor.Tensor, - freqs_cis: torch.Tensor, + bcast_freq_xq: tt_lib.tensor.complex_tensor, + bcast_freq_xk: tt_lib.tensor.complex_tensor, positions: tt_lib.tensor.Tensor, mask: Optional[torch.Tensor], + seqlen: int, ) -> tt_lib.tensor.Tensor: - _, bsz, seqlen, _ = x.shape() + _, bsz, _, _ = x.shape() + xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) - xq = tt_lib.tensor.reshape(xq, bsz, seqlen, self.n_heads, self.args.head_dim) + xq = tt_to_torch_tensor(xq).to(torch.float32) + xk = tt_to_torch_tensor(xk).to(torch.float32) + xv = tt_to_torch_tensor(xv).to(torch.float32) - xk = tt_lib.tensor.reshape(xk, bsz, seqlen, self.n_kv_heads, self.args.head_dim) + xq = xq[:, :, :seqlen, :] + xk = xk[:, :, :seqlen, :] + xv = xv[:, :, :seqlen, :] + + xq = torch_to_tt_tensor_rm(xq, self.device, put_on_device=True) + xk = torch_to_tt_tensor_rm(xk, self.device, put_on_device=True) + xv = torch_to_tt_tensor_rm(xv, self.device, put_on_device=True) + xq = tt_lib.tensor.reshape(xq, bsz, seqlen, self.n_heads, self.args.head_dim) + xk = tt_lib.tensor.reshape(xk, bsz, seqlen, self.n_kv_heads, self.args.head_dim) xv = tt_lib.tensor.reshape(xv, bsz, seqlen, self.n_kv_heads, self.args.head_dim) xq = tt_to_torch_tensor(xq).to(torch.float32) xk = tt_to_torch_tensor(xk).to(torch.float32) xv = tt_to_torch_tensor(xv).to(torch.float32) - if self.args.FALLBACK_ROTARY_EMBEDDING: - xq, xk = fallback_apply_rotary_emb(xq, xk, freqs_cis=freqs_cis) - else: - xq, xk = apply_rotary_emb_type2( - xq, xk, freqs_cis=freqs_cis, device=self.device, mem_config=self.args.out_mem_config - ) + xq, xk = apply_rotary_emb_type2( + xq, xk, bcast_freq_xq, bcast_freq_xk, device=self.device, mem_config=self.args.out_mem_config + ) # The cache is a rotating buffer positions = tt_to_torch_tensor(positions).squeeze(0).squeeze(0).squeeze(0) @@ -166,131 +177,71 @@ def forward( self.cache_k[:bsz, :cur_pos, ...], self.cache_v[:bsz, :cur_pos, ...], self.repeats ) - xq = torch_to_tt_tensor_rm(xq, self.device) + xq = torch_to_tt_tensor(xq, self.device) + query = tt_lib.tensor.transpose(xq, 1, -2, output_mem_config=self.args.out_mem_config) + desired_score_shape = query.shape().copy() + desired_score_shape[-1] = key.shape()[1] xq.deallocate() - key = tt_lib.tensor.transpose(key, 1, -2, output_mem_config=self.args.out_mem_config) + + key = format_tensor(key, tt_lib.tensor.Layout.TILE, self.device, self.output_mem_config) + key = tt_lib.tensor.permute(key, [0, 2, 3, 1]) + key = format_tensor(key, tt_lib.tensor.Layout.TILE, self.device, self.output_mem_config) + + value = format_tensor(value, tt_lib.tensor.Layout.TILE, self.device, self.output_mem_config) value = tt_lib.tensor.transpose(value, 1, -2, output_mem_config=self.args.out_mem_config) + value = format_tensor(value, tt_lib.tensor.Layout.TILE, self.device, self.output_mem_config) + + query = format_tensor(query, tt_lib.tensor.Layout.TILE, self.device, self.output_mem_config) - key = tt_lib.tensor.transpose(key, -2, -1, output_mem_config=self.args.out_mem_config) scores = tt_lib.tensor.bmm(query, key, output_mem_config=self.args.out_mem_config) key.deallocate() scores = tt_lib.tensor.mul_unary(scores, self.scale, output_mem_config=self.args.out_mem_config) - scores = tt_to_torch_tensor(scores) - if mask is not None: - if mask.dim() == 4: - mask = mask.squeeze() - scores += mask[None, None, ...] - scores = torch_to_tt_tensor_rm(scores, self.device, put_on_device=False) + if mask is not None: + mask = tt_lib.tensor.permute(mask, [2, 3, 0, 1]) + scores = tt_lib.tensor.permute(scores, [2, 3, 0, 1]) + + scores = tt_lib.tensor.bcast( + scores, + mask, + tt_lib.tensor.BcastOpMath.ADD, + tt_lib.tensor.BcastOpDim.HW, + output_mem_config=self.output_mem_config, + ) + scores = tt_lib.tensor.permute(scores, [2, 3, 0, 1]) + desired_output_shape = [bsz, 32, seqlen, seqlen] + desired_output_shape[-1] = value.shape()[-1] if self.args.FALLBACK_SOFTMAX: scores = fallback_ops.softmax(scores, dim=-1) else: scores = tt_lib.tensor.softmax(scores, output_mem_config=self.args.out_mem_config) - output = tt_lib.tensor.bmm( scores, value, output_mem_config=self.args.out_mem_config ) # (bs, n_local_heads, slen, head_dim) value.deallocate() scores.deallocate() + output = unpad_from_zero(output, desired_output_shape) + output = torch_to_tt_tensor_rm(output, self.device, put_on_device=False) + output = tt_lib.tensor.transpose(output, 1, -2, output_mem_config=self.args.out_mem_config) output = fallback_ops.reshape(output, 1, bsz, seqlen, -1) - return self.wo(output) - - -def _reshape_for_broadcast(freqs_cis: torch.Tensor, x_shape, x_ndim) -> torch.Tensor: - """ - freqs_cis: complex - (seq_len, head_dim / 2) - x: complex - (bsz, seq_len, head_dim / 2) - """ - ndim = x_ndim - assert 1 < ndim - assert freqs_cis.shape == (x_shape[1], x_shape[-1]), ( - freqs_cis.shape, - (x_shape[1], x_shape[-1]), - ) - shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x_shape)] - return freqs_cis.view(*shape) - - -def fallback_apply_rotary_emb( - xq: torch.Tensor, - xk: torch.Tensor, - freqs_cis: torch.Tensor, -) -> Tuple[torch.Tensor, torch.Tensor]: - xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) - xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) - freqs_cis = _reshape_for_broadcast(freqs_cis, xq_.shape, xq_.ndim) - xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) - xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) - return xq_out.type_as(xq), xk_out.type_as(xk) - -def apply_rotary_emb_type1( - t_xq: torch.Tensor, t_xk: torch.Tensor, freqs_cis: torch.Tensor, device, mem_config -) -> Tuple[torch.Tensor, torch.Tensor]: - xq_shape = list(copy.deepcopy(t_xq.shape)) - xq_shape[-1] = xq_shape[-1] // 2 - freqs_cis = _reshape_for_broadcast(freqs_cis, xq_shape, 4) - - freq_real = torch_to_tt_tensor_rm(freqs_cis.real, device) - freq_img = torch_to_tt_tensor_rm(freqs_cis.imag, device) - freqs_cis = tt_lib.tensor.concat([freq_real, freq_img], -1) - - xq_real = torch_to_tt_tensor_rm(t_xq[..., :, :, ::2], device) - xq_img = torch_to_tt_tensor_rm(t_xq[..., :, :, 1::2], device) - xq = tt_lib.tensor.concat([xq_real, xq_img], -1) - - xk_real = torch_to_tt_tensor_rm(t_xk[..., :, :, ::2], device) - xk_img = torch_to_tt_tensor_rm(t_xk[..., :, :, 1::2], device) - xk = tt_lib.tensor.concat([xk_real, xk_img], -1) - - BCH = tt_lib.tensor.BcastOpDim.H - BCMUL = tt_lib.tensor.BcastOpMath.MUL - - t_one = tt_lib.tensor.ones_like(xq) - bcast_freq = tt_lib.tensor.bcast(t_one, freqs_cis, BCMUL, BCH) - xq_out = tt_lib.tensor.complex_mul(xq, bcast_freq, mem_config) - - t_one = tt_lib.tensor.ones_like(xk) - bcast_freq = tt_lib.tensor.bcast(t_one, freqs_cis, BCMUL, BCH) - xk_out = tt_lib.tensor.complex_mul(xk, bcast_freq, mem_config) - - xq, xk = tt_to_torch_tensor(xq_out).to(torch.float32), tt_to_torch_tensor(xk_out).to(torch.float32) - - shapes = xq.shape - dindex = shapes[3] // 2 - xq_out = torch.empty(xq.shape) - xq_out[:, :, :, ::2] = xq[:, :, :, :dindex] - xq_out[:, :, :, 1::2] = xq[:, :, :, dindex:] - - shapes = xk.shape - dindex = shapes[3] // 2 - xk_out = torch.empty(xk.shape) - xk_out[:, :, :, ::2] = xk[:, :, :, :dindex] - xk_out[:, :, :, 1::2] = xk[:, :, :, dindex:] - return xq_out, xk_out + desired_output_shape = output.shape().copy() + output = format_tensor(output, tt_lib.tensor.Layout.TILE, self.device, self.output_mem_config) + output = self.wo(output) + return output def apply_rotary_emb_type2( - t_xq: torch.Tensor, t_xk: torch.Tensor, freqs_cis: torch.Tensor, device, mem_config + t_xq: torch.Tensor, t_xk: torch.Tensor, bcast_freq_xq, bcast_freq_xk, device, mem_config ) -> Tuple[torch.Tensor, torch.Tensor]: - xq_shape = list(copy.deepcopy(t_xq.shape)) - xq_shape[-1] = xq_shape[-1] // 2 - freqs_cis = _reshape_for_broadcast(freqs_cis, xq_shape, 4) - - freq_real = torch_to_tt_tensor_rm(freqs_cis.real, device) - freq_img = torch_to_tt_tensor_rm(freqs_cis.imag, device) - freqs_cis = tt_lib.tensor.complex_tensor(freq_real, freq_img) - - freq_real.deallocate() - freq_img.deallocate() - xq_real = torch_to_tt_tensor_rm(t_xq[..., :, :, ::2], device) xq_img = torch_to_tt_tensor_rm(t_xq[..., :, :, 1::2], device) + xq = tt_lib.tensor.complex_tensor(xq_real, xq_img) xq_real.deallocate() @@ -303,40 +254,10 @@ def apply_rotary_emb_type2( xk_real.deallocate() xk_img.deallocate() - BCH = tt_lib.tensor.BcastOpDim.HW - BCMUL = tt_lib.tensor.BcastOpMath.MUL - - t_one_xq = tt_lib.tensor.ones_like(xq.real, output_mem_config=mem_config) - t_one_xq = tt_lib.tensor.permute(t_one_xq, (3, 1, 2, 0), output_mem_config=mem_config) - freqs_real = tt_lib.tensor.permute(freqs_cis.real, (3, 1, 2, 0), output_mem_config=mem_config) - freqs_imag = tt_lib.tensor.permute(freqs_cis.imag, (3, 1, 2, 0), output_mem_config=mem_config) - bcast_freq_re_xq = tt_lib.tensor.bcast(t_one_xq, freqs_real, BCMUL, BCH, output_mem_config=mem_config) - bcast_freq_im_xq = tt_lib.tensor.bcast(t_one_xq, freqs_imag, BCMUL, BCH, output_mem_config=mem_config) - bcast_freq_re_xq = tt_lib.tensor.permute(bcast_freq_re_xq, (3, 1, 2, 0), output_mem_config=mem_config) - bcast_freq_im_xq = tt_lib.tensor.permute(bcast_freq_im_xq, (3, 1, 2, 0), output_mem_config=mem_config) - t_one_xq.deallocate() - bcast_freq_xq = tt_lib.tensor.complex_tensor(bcast_freq_re_xq, bcast_freq_im_xq) - bcast_freq_re_xq.deallocate() - bcast_freq_im_xq.deallocate() xq_out = tt_lib.tensor.complex_mul(xq, bcast_freq_xq, output_mem_config=mem_config) - bcast_freq_xq.deallocate() - - t_one_xk = tt_lib.tensor.ones_like(xk.real, output_mem_config=mem_config) - t_one_xk = tt_lib.tensor.permute(t_one_xk, (3, 1, 2, 0), output_mem_config=mem_config) - bcast_freq_re_xk = tt_lib.tensor.bcast(t_one_xk, freqs_real, BCMUL, BCH, output_mem_config=mem_config) - bcast_freq_im_xk = tt_lib.tensor.bcast(t_one_xk, freqs_imag, BCMUL, BCH, output_mem_config=mem_config) - bcast_freq_re_xk = tt_lib.tensor.permute(bcast_freq_re_xk, (3, 1, 2, 0), output_mem_config=mem_config) - bcast_freq_im_xk = tt_lib.tensor.permute(bcast_freq_im_xk, (3, 1, 2, 0), output_mem_config=mem_config) - bcast_freq_xk = tt_lib.tensor.complex_tensor(bcast_freq_re_xk, bcast_freq_im_xk) - t_one_xk.deallocate() - bcast_freq_re_xk.deallocate() - bcast_freq_im_xk.deallocate() - freqs_cis.deallocate() - freqs_real.deallocate() - freqs_imag.deallocate() + xk_out = tt_lib.tensor.complex_mul(xk, bcast_freq_xk, output_mem_config=mem_config) - bcast_freq_xk.deallocate() xq_out = tt_lib.tensor.concat([xq_out.real, xq_out.imag], -1, mem_config) xk_out = tt_lib.tensor.concat([xk_out.real, xk_out.imag], -1, mem_config) xq, xk = tt_to_torch_tensor(xq_out).to(torch.float32), tt_to_torch_tensor(xk_out).to(torch.float32) @@ -344,6 +265,7 @@ def apply_rotary_emb_type2( xq_out.deallocate() xk_out.deallocate() # FIXME: move this operation to on-device - should be easy. + shapes = xq.shape dindex = shapes[3] // 2 xq_out = torch.empty(xq.shape) @@ -358,4 +280,5 @@ def apply_rotary_emb_type2( xk_out = torch.empty(xk.shape) xk_out[:, :, :, ::2] = xk[:, :, :, :dindex] xk_out[:, :, :, 1::2] = xk[:, :, :, dindex:] + return xq_out, xk_out diff --git a/models/experimental/mistral/tt/mistral_configuration.py b/models/experimental/mistral/tt/mistral_configuration.py index 74a09609702..f29eb5f63b3 100644 --- a/models/experimental/mistral/tt/mistral_configuration.py +++ b/models/experimental/mistral/tt/mistral_configuration.py @@ -19,7 +19,6 @@ class TtModelArgs: max_batch_size: int = 0 FALLBACK_SOFTMAX: bool = False - FALLBACK_ROTARY_EMBEDDING: bool = False FALLBACK_EMPTY: bool = False FALLBACK_SCATTER: bool = False FALLBACK_DRAM: bool = True diff --git a/models/experimental/mistral/tt/mistral_feed_forward.py b/models/experimental/mistral/tt/mistral_feed_forward.py index 8ba8ce92b09..200f178129f 100644 --- a/models/experimental/mistral/tt/mistral_feed_forward.py +++ b/models/experimental/mistral/tt/mistral_feed_forward.py @@ -6,7 +6,7 @@ import tt_lib from models.experimental.mistral.tt.mistral_configuration import TtModelArgs from models.utility_functions import torch_to_tt_tensor_rm, tt_to_torch_tensor -from models.experimental.mistral.mistral_helper_funcs import Linear as TtLinear +from models.experimental.mistral.mistral_helper_funcs import Linear as TtLinear, format_tensor, unpad_from_zero class TtFeedForward(nn.Module): @@ -16,11 +16,13 @@ def __init__( base_address=None, device=None, tt_cache_path=None, + output_mem_config=None, ): super().__init__() self.device = device self.args = args + self.output_mem_config = output_mem_config self.w1_weights = tt_lib.tensor.load_tensor( tt_cache_path + base_address + "w1.weight" + str(self.args.WEIGHTS_DTYPE) + ".bin" ) @@ -54,4 +56,5 @@ def __init__( def forward(self, x: tt_lib.tensor.Tensor) -> tt_lib.tensor.Tensor: silu_out = tt_lib.tensor.silu(self.w1(x)) x = tt_lib.tensor.mul(silu_out, self.w3(x)) - return self.w2(x) + out = self.w2(x) + return out diff --git a/models/experimental/mistral/tt/mistral_rms_norm.py b/models/experimental/mistral/tt/mistral_rms_norm.py index c323e434e13..67c35592067 100644 --- a/models/experimental/mistral/tt/mistral_rms_norm.py +++ b/models/experimental/mistral/tt/mistral_rms_norm.py @@ -3,7 +3,8 @@ # SPDX-License-Identifier: Apache-2.0 import torch.nn as nn import tt_lib -from models.utility_functions import tt_to_torch_tensor, torch_to_tt_tensor_rm +from models.utility_functions import torch_to_tt_tensor_rm +from models.experimental.mistral.mistral_helper_funcs import format_tensor, unpad_from_zero class TtRMSNorm(nn.Module): @@ -11,14 +12,18 @@ def __init__( self, dim: int, eps: float = 1e-6, + device=None, base_address=None, tt_cache_path=None, + output_mem_config=None, ): super().__init__() self.eps = eps - + self.device = device + self.output_mem_config = output_mem_config # bfp8 reduces PCC for so using weights in bfloat16 self.weight = tt_lib.tensor.load_tensor(tt_cache_path + base_address + "weightDataType.BFLOAT16.bin") def forward(self, x: tt_lib.tensor.Tensor) -> tt_lib.tensor.Tensor: - return tt_lib.tensor.rmsnorm(x, self.eps, self.weight) + x = tt_lib.tensor.rmsnorm(x, self.eps, self.weight) + return x diff --git a/models/experimental/mistral/tt/mistral_transformer.py b/models/experimental/mistral/tt/mistral_transformer.py index 7b87ee2307c..037755fe5f7 100644 --- a/models/experimental/mistral/tt/mistral_transformer.py +++ b/models/experimental/mistral/tt/mistral_transformer.py @@ -8,7 +8,12 @@ from models.experimental.mistral.tt.mistral_configuration import TtModelArgs from models.experimental.mistral.tt.mistral_transformer_block import TtTransformerBlock from models.experimental.mistral.tt.mistral_rms_norm import TtRMSNorm -from models.experimental.mistral.mistral_helper_funcs import Linear as TtLinear +from models.experimental.mistral.mistral_helper_funcs import ( + Linear as TtLinear, + format_tensor, + unpad_from_zero, + get_freqs_cis, +) from models.utility_functions import torch_to_tt_tensor_rm, tt_to_torch_tensor from typing import Optional @@ -38,7 +43,9 @@ def __init__( embedding_weights = torch.load(tt_cache_path + "tok_embeddings.weight.pt") self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim, _weight=embedding_weights) - + self.output_mem_config = tt_lib.tensor.MemoryConfig( + tt_lib.tensor.TensorMemoryLayout.INTERLEAVED, tt_lib.tensor.BufferType.DRAM + ) self.layers = torch.nn.ModuleList( [ TtTransformerBlock( @@ -46,6 +53,7 @@ def __init__( base_address=f"layers.{i}.", device=self.device, tt_cache_path=tt_cache_path, + output_mem_config=self.output_mem_config, ) for i in range(args.n_layers) ] @@ -55,6 +63,8 @@ def __init__( base_address=f"norm.", eps=args.norm_eps, tt_cache_path=tt_cache_path, + device=self.device, + output_mem_config=self.output_mem_config, ) self.output_weight = tt_lib.tensor.load_tensor( @@ -72,9 +82,17 @@ def forward( input_ids: torch.Tensor, positions: torch.Tensor, ): + seqlen = input_ids.shape[-1] + bsz = input_ids.shape[0] h = self.tok_embeddings(input_ids) input_ids = torch_to_tt_tensor_rm(input_ids, self.device, put_on_device=False) freqs_cis = self.freqs_cis[positions] + query_shape = [bsz, seqlen, self.args.n_heads, self.args.head_dim // 2] + key_shape = [bsz, seqlen, self.args.n_kv_heads, self.args.head_dim // 2] + bcast_freq_xq, bcast_freq_xk = get_freqs_cis( + freqs_cis, query_shape, key_shape, self.device, self.output_mem_config + ) + mask: Optional[torch.Tensor] = None if input_ids.shape()[-1] > 1: seqlen = input_ids.shape()[-1] @@ -91,9 +109,20 @@ def forward( mask = tt_lib.tensor.triu(mask, diagonal) mask = tt_lib.tensor.log(mask) mask = tt_to_torch_tensor(mask) + mask = torch_to_tt_tensor_rm(mask, self.device, put_on_device=False) + mask = format_tensor(mask, tt_lib.tensor.Layout.TILE, self.device, self.output_mem_config, pad_value=-10000) positions = torch_to_tt_tensor_rm(positions, self.device, put_on_device=False) h = torch_to_tt_tensor_rm(h, self.device, put_on_device=False) + h = format_tensor(h, tt_lib.tensor.Layout.TILE, self.device, self.output_mem_config) for layer in self.layers: - h = layer(h, freqs_cis, positions, mask) - return self.output(self.norm(h)) + h = layer(h, bcast_freq_xq, bcast_freq_xk, positions, mask, seqlen) + + bcast_freq_xq.deallocate() + bcast_freq_xk.deallocate() + output = self.output(self.norm(h)) + desired_output_shape = list(output.shape()) + desired_output_shape[2] = seqlen + output = unpad_from_zero(output, desired_output_shape) + output = torch_to_tt_tensor_rm(output, self.device, put_on_device=False) + return output diff --git a/models/experimental/mistral/tt/mistral_transformer_block.py b/models/experimental/mistral/tt/mistral_transformer_block.py index 2a3cce715ec..242cc01fcbd 100644 --- a/models/experimental/mistral/tt/mistral_transformer_block.py +++ b/models/experimental/mistral/tt/mistral_transformer_block.py @@ -9,6 +9,8 @@ from models.experimental.mistral.tt.mistral_feed_forward import TtFeedForward from models.experimental.mistral.tt.mistral_rms_norm import TtRMSNorm from models.experimental.mistral.tt.mistral_configuration import TtModelArgs +from models.experimental.mistral.mistral_helper_funcs import format_tensor, unpad_from_zero +from models.utility_functions import torch_to_tt_tensor_rm class TtTransformerBlock(nn.Module): @@ -18,35 +20,55 @@ def __init__( device=None, base_address=None, tt_cache_path=None, + output_mem_config=None, ): super().__init__() self.n_heads = args.n_heads self.dim = args.dim self.device = device - self.attention = TtAttention(args, f"{base_address}attention.", device, tt_cache_path=tt_cache_path) - self.feed_forward = TtFeedForward(args, f"{base_address}feed_forward.", device, tt_cache_path=tt_cache_path) + self.output_mem_config = output_mem_config + self.attention = TtAttention( + args, + f"{base_address}attention.", + device, + tt_cache_path=tt_cache_path, + output_mem_config=self.output_mem_config, + ) + self.feed_forward = TtFeedForward( + args, + f"{base_address}feed_forward.", + device, + tt_cache_path=tt_cache_path, + output_mem_config=self.output_mem_config, + ) self.attention_norm = TtRMSNorm( args.dim, base_address=f"{base_address}attention_norm.", eps=args.norm_eps, tt_cache_path=tt_cache_path, + device=self.device, + output_mem_config=self.output_mem_config, ) self.ffn_norm = TtRMSNorm( args.dim, base_address=f"{base_address}ffn_norm.", eps=args.norm_eps, tt_cache_path=tt_cache_path, + device=self.device, + output_mem_config=self.output_mem_config, ) self.args = args def forward( self, x: tt_lib.tensor.Tensor, - freqs_cis: torch.Tensor, + bcast_freq_xq: tt_lib.tensor.complex_tensor, + bcast_freq_xk: tt_lib.tensor.complex_tensor, positions: tt_lib.tensor.Tensor, mask: Optional[torch.Tensor], + seqlen: int, ) -> tt_lib.tensor.Tensor: - r = self.attention.forward(self.attention_norm(x), freqs_cis, positions, mask) + r = self.attention.forward(self.attention_norm(x), bcast_freq_xq, bcast_freq_xk, positions, mask, seqlen) h = tt_lib.tensor.add(x, r) x.deallocate() r = self.feed_forward.forward(self.ffn_norm(h))