From 8106601f9761b02106135ceae5ee02b28cdb745e Mon Sep 17 00:00:00 2001
From: vigneshkeerthivasanx
Date: Tue, 21 Nov 2023 09:57:54 +0000
Subject: [PATCH] #3824: cache weight tensors for mistral

---
 models/experimental/mistral/demo/gs_demo.py    |  2 +
 models/experimental/mistral/mistral_utils.py   | 30 ++++++++++++
 .../mistral/tests/test_mistral_attention.py    |  5 +-
 .../tests/test_mistral_feed_forward.py         |  6 ++-
 .../mistral/tests/test_mistral_rms_norm.py     |  6 +--
 .../mistral/tests/test_mistral_transformer.py  |  3 +
 .../tests/test_mistral_transformer_block.py    |  6 ++-
 .../mistral/tests/test_perf_mistral.py         |  2 +
 .../mistral/tt/mistral_attention.py            | 46 +++++--------
 .../mistral/tt/mistral_feed_forward.py         | 30 ++++--------
 .../mistral/tt/mistral_rms_norm.py             | 15 ++----
 .../mistral/tt/mistral_transformer.py          | 23 ++++++----
 .../mistral/tt/mistral_transformer_block.py    | 14 +++---
 13 files changed, 96 insertions(+), 92 deletions(-)

diff --git a/models/experimental/mistral/demo/gs_demo.py b/models/experimental/mistral/demo/gs_demo.py
index 98dfe9a3342..da19d3bce31 100644
--- a/models/experimental/mistral/demo/gs_demo.py
+++ b/models/experimental/mistral/demo/gs_demo.py
@@ -35,11 +35,13 @@ def test_gs_demo_single_input_inference(batch_size, model_location_generator, de
     model_args.max_batch_size = batch_size
     model_args.n_layers = 32
 
+    tt_cache_path = "/mnt/MLPerf/tt_dnn-models/tt/Mistral/"
     tt_model = TtTransformer(
         args=model_args,
         state_dict=state_dict,
         device=device,
         base_address=base_address,
+        tt_cache_path=tt_cache_path,
     )
 
     tt_output = generate(
diff --git a/models/experimental/mistral/mistral_utils.py b/models/experimental/mistral/mistral_utils.py
index b75659c9779..c39ac7393fb 100644
--- a/models/experimental/mistral/mistral_utils.py
+++ b/models/experimental/mistral/mistral_utils.py
@@ -9,6 +9,10 @@
 from sentencepiece import SentencePieceProcessor
 from pathlib import Path
 from models.utility_functions import tt_to_torch_tensor
+import tt_lib
+import json
+from models.experimental.mistral.tt.mistral_configuration import TtModelArgs
+from tt_lib.utils import pad_weight
 
 
 class Tokenizer:
@@ -76,3 +80,29 @@
 
     for i, x in enumerate(encoded_prompts):
         res.append(tokenizer.decode(x[:min_prompt_len] + generated[i].tolist()))
     return res
+
+
+def cache_weights_in_weka(model_location_generator, device, dtype, reset_seeds):
+    mistral_path = model_location_generator("mistral-7B-v0.1", model_subdir="Mistral")
+    state_dict = torch.load(mistral_path / "consolidated.00.pth")
+    with open(mistral_path / "params.json", "r") as f:
+        model_args = TtModelArgs(**json.loads(f.read()))
+    weights_dtype = dtype
+
+    # weights are first written to "models/experimental/mistral/weights/" and then moved to the Weka path
+    file_name = "models/experimental/mistral/weights/"
+    for key, value in state_dict.items():
+        if len(value.shape) == 1:
+            value = value.unsqueeze(0).unsqueeze(0).unsqueeze(0)
+        else:
+            value = value.unsqueeze(0).unsqueeze(0)
+        # pad to tile-aligned dims (multiples of 32) before tilizing
+        if value.shape[-2] % 32 != 0 or value.shape[-1] % 32 != 0:
+            value = pad_weight(value)
+        value = tt_lib.tensor.Tensor(
+            value.reshape(-1).tolist(),
+            value.shape,
+            weights_dtype,
+            tt_lib.tensor.Layout.ROW_MAJOR,
+        ).to(tt_lib.tensor.Layout.TILE)
+        tt_lib.tensor.dump_tensor(file_name + str(key) + str(weights_dtype) + ".bin", value)
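The cache file names above are built by plain string concatenation: the state-dict key, then `str(weights_dtype)`, then `.bin` — so a bfloat16 tensor's file ends in `DataType.BFLOAT16.bin`, which is exactly the suffix `TtRMSNorm` reconstructs later in this patch. A minimal sketch of the convention from the consumer side; the `cached_weight_path` helper and the `BFLOAT8_B` example value are illustrative assumptions, not part of the change:

```python
import tt_lib

def cached_weight_path(tt_cache_path, base_address, name, dtype):
    # mirror dump_tensor's naming above: <cache dir><state-dict key><str(dtype)>.bin
    return tt_cache_path + base_address + name + str(dtype) + ".bin"

# -> "/mnt/MLPerf/tt_dnn-models/tt/Mistral/layers.0.attention.wq.weightDataType.BFLOAT8_B.bin"
path = cached_weight_path(
    "/mnt/MLPerf/tt_dnn-models/tt/Mistral/",
    "layers.0.attention.",
    "wq.weight",
    tt_lib.tensor.DataType.BFLOAT8_B,
)
wq_weights = tt_lib.tensor.load_tensor(path)  # comes back in TILE layout, as dumped
```

Because `cache_weights_in_weka` dumps the whole state dict at a single `dtype`, the cache presumably has to be populated once per dtype the modules read (for example bfloat16 for the norm weights, which are kept out of bfp8 to preserve PCC).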
diff --git a/models/experimental/mistral/tests/test_mistral_attention.py b/models/experimental/mistral/tests/test_mistral_attention.py
index 32f46013c70..3e8ff6f695a 100644
--- a/models/experimental/mistral/tests/test_mistral_attention.py
+++ b/models/experimental/mistral/tests/test_mistral_attention.py
@@ -54,7 +54,7 @@ def test_mistral_attention_inference(
 ):
     mistral_path = model_location_generator("mistral-7B-v0.1", model_subdir="Mistral")
     state_dict = torch.load(mistral_path / "consolidated.00.pth")
-    base_address = f""
+    base_address = f"layers.0.attention."
     with open(mistral_path / "params.json", "r") as f:
         model_args = TtModelArgs(**json.loads(f.read()))
     if True:
@@ -68,11 +68,12 @@ def test_mistral_attention_inference(
     model_args.FALLBACK_EMPTY = empty_ondevice
     model_args.FALLBACK_SCATTER = scatter_ondevice
     model_args.WEIGHTS_DTYPE = dtype
+    tt_cache_path = "/mnt/MLPerf/tt_dnn-models/tt/Mistral/"
     tt_model = TtAttention(
         args=model_args,
-        state_dict=state_dict,
         device=device,
         base_address=base_address,
+        tt_cache_path=tt_cache_path,
     )
     input = torch.randn(1, 11, 4096)
     empty_tensor = torch.zeros((11, 64))
diff --git a/models/experimental/mistral/tests/test_mistral_feed_forward.py b/models/experimental/mistral/tests/test_mistral_feed_forward.py
index 0a53b8099c1..2fcec5dbc6d 100644
--- a/models/experimental/mistral/tests/test_mistral_feed_forward.py
+++ b/models/experimental/mistral/tests/test_mistral_feed_forward.py
@@ -27,7 +27,7 @@ def test_mistral_feed_forward_inference(pcc, model_location_generator, device, dtype, reset_seeds):
     mistral_path = model_location_generator("mistral-7B-v0.1", model_subdir="Mistral")
     state_dict = torch.load(mistral_path / "consolidated.00.pth")
-    base_address = f""
+    base_address = f"layers.0.feed_forward."
     with open(mistral_path / "params.json", "r") as f:
         model_args = TtModelArgs(**json.loads(f.read()))
 
@@ -37,11 +37,13 @@ def test_mistral_feed_forward_inference(pcc, model_location_generator, device, d
     reference_model = FeedForward(args=model_args)
     reference_model.load_state_dict(state_dict)
 
+    tt_cache_path = "/mnt/MLPerf/tt_dnn-models/tt/Mistral/"
+
     tt_model = TtFeedForward(
         args=model_args,
-        state_dict=state_dict,
         device=device,
         base_address=base_address,
+        tt_cache_path=tt_cache_path,
     )
     input = torch.rand(1, 11, 4096)
     reference_output = reference_model(input)
diff --git a/models/experimental/mistral/tests/test_mistral_rms_norm.py b/models/experimental/mistral/tests/test_mistral_rms_norm.py
index d685b7ebe2c..a87025bbad9 100644
--- a/models/experimental/mistral/tests/test_mistral_rms_norm.py
+++ b/models/experimental/mistral/tests/test_mistral_rms_norm.py
@@ -23,7 +23,7 @@ def test_mistral_rms_norm_inference(pcc, model_location_generator, device, reset_seeds):
     mistral_path = model_location_generator("mistral-7B-v0.1", model_subdir="Mistral")
     state_dict = torch.load(mistral_path / "consolidated.00.pth")
-    base_address = f""
+    base_address = f"layers.0.attention_norm."
     with open(mistral_path / "params.json", "r") as f:
         model_args = TtModelArgs(**json.loads(f.read()))
 
@@ -34,11 +34,11 @@ def test_mistral_rms_norm_inference(pcc, model_location_generator, device, reset
     reference_model = RMSNorm(dim=dim)
     reference_model.load_state_dict(state_dict)
 
+    tt_cache_path = "/mnt/MLPerf/tt_dnn-models/tt/Mistral/"
     tt_model = TtRMSNorm(
         dim=dim,
-        state_dict=state_dict,
-        device=device,
         base_address=base_address,
+        tt_cache_path=tt_cache_path,
     )
     input = torch.rand(1, 11, 4096)
     reference_output = reference_model(input)
diff --git a/models/experimental/mistral/tests/test_mistral_transformer.py b/models/experimental/mistral/tests/test_mistral_transformer.py
index 07422437a0d..b1818ac481f 100644
--- a/models/experimental/mistral/tests/test_mistral_transformer.py
+++ b/models/experimental/mistral/tests/test_mistral_transformer.py
@@ -38,17 +38,20 @@ def test_mistral_transformer_inference(pcc, model_location_generator, device, dt
     with open(mistral_path / "params.json", "r") as f:
         model_args = TtModelArgs(**json.loads(f.read()))
 
     model_args.max_batch_size = 1
     model_args.n_layers = 32
     model_args.WEIGHTS_DTYPE = dtype
     reference_model = Transformer(args=model_args)
     reference_model.load_state_dict(state_dict)
 
+    tt_cache_path = "/mnt/MLPerf/tt_dnn-models/tt/Mistral/"
+
     tt_model = TtTransformer(
         args=model_args,
         state_dict=state_dict,
         device=device,
         base_address=base_address,
+        tt_cache_path=tt_cache_path,
     )
 
     encoded_prompts = [tokenizer.encode(prompt) for prompt in prompts]
diff --git a/models/experimental/mistral/tests/test_mistral_transformer_block.py b/models/experimental/mistral/tests/test_mistral_transformer_block.py
index 7499497b638..dd5a714b38c 100644
--- a/models/experimental/mistral/tests/test_mistral_transformer_block.py
+++ b/models/experimental/mistral/tests/test_mistral_transformer_block.py
@@ -28,7 +28,7 @@ def test_mistral_transformer_block_inference(pcc, model_location_generator, device, dtype, reset_seeds):
     mistral_path = model_location_generator("mistral-7B-v0.1", model_subdir="Mistral")
     state_dict = torch.load(mistral_path / "consolidated.00.pth")
-    base_address = f""
+    base_address = f"layers.0."
     with open(mistral_path / "params.json", "r") as f:
         model_args = TtModelArgs(**json.loads(f.read()))
 
@@ -39,11 +39,13 @@ def test_mistral_transformer_block_inference(pcc, model_location_generator, devi
     reference_model = TransformerBlock(args=model_args)
     reference_model.load_state_dict(state_dict)
 
+    tt_cache_path = "/mnt/MLPerf/tt_dnn-models/tt/Mistral/"
+
     tt_model = TtTransformerBlock(
         args=model_args,
-        state_dict=state_dict,
         device=device,
         base_address=base_address,
+        tt_cache_path=tt_cache_path,
     )
 
     input = torch.randn(1, 11, 4096)
diff --git a/models/experimental/mistral/tests/test_perf_mistral.py b/models/experimental/mistral/tests/test_perf_mistral.py
index a4e98b82a06..111904a2979 100644
--- a/models/experimental/mistral/tests/test_perf_mistral.py
+++ b/models/experimental/mistral/tests/test_perf_mistral.py
@@ -55,11 +55,13 @@ def run_perf_mistral(expected_inference_time, expected_compile_time, device, mod
         Path(mistral_path), n_layers=32, max_batch_size=max_batch_size, is_whole_model=True
     )
 
+    tt_cache_path = "/mnt/MLPerf/tt_dnn-models/tt/Mistral/"
     tt_model = TtTransformer(
         args=model_args,
         state_dict=state_dict,
         device=device,
         base_address=base_address,
+        tt_cache_path=tt_cache_path,
     )
 
     with torch.no_grad():
diff --git a/models/experimental/mistral/tt/mistral_attention.py b/models/experimental/mistral/tt/mistral_attention.py
index 8e6530e6648..2c0c3ae15bd 100644
--- a/models/experimental/mistral/tt/mistral_attention.py
+++ b/models/experimental/mistral/tt/mistral_attention.py
@@ -19,13 +19,12 @@ def __init__(
         args: TtModelArgs,
         base_address=None,
         device=None,
-        state_dict=None,
+        tt_cache_path=None,
     ):
         super().__init__()
         self.args = args
         self.device = device
         self.base_address = base_address
-        self.state_dict = state_dict
 
         self.n_heads: int = args.n_heads
         self.n_kv_heads: int = args.n_kv_heads
@@ -35,13 +34,8 @@ def __init__(
 
         self.scale = self.args.head_dim**-0.5
 
-        wq_weights = self.state_dict[f"{base_address}wq.weight"]
-        ref_wq_weights = wq_weights.unsqueeze(0).unsqueeze(0)
-        self.wq_weights = tt_lib.tensor.Tensor(
-            ref_wq_weights.reshape(-1).tolist(),
-            ref_wq_weights.shape,
-            self.args.WEIGHTS_DTYPE,
-            tt_lib.tensor.Layout.ROW_MAJOR,
+        self.wq_weights = tt_lib.tensor.load_tensor(
+            tt_cache_path + base_address + "wq.weight" + str(self.args.WEIGHTS_DTYPE) + ".bin"
         )
         self.wq = TtLinear(
             args.dim,
             args.n_heads * args.head_dim,
             self.wq_weights,
             device=self.device,
             output_mem_config=self.args.out_mem_config,
         )
 
-        wk_weights = self.state_dict[f"{base_address}wk.weight"]
-        ref_wk_weights = wk_weights.unsqueeze(0).unsqueeze(0)
-        self.wk_weights = tt_lib.tensor.Tensor(
-            ref_wk_weights.reshape(-1).tolist(),
-            ref_wk_weights.shape,
-            self.args.WEIGHTS_DTYPE,
-            tt_lib.tensor.Layout.ROW_MAJOR,
+        self.wk_weights = tt_lib.tensor.load_tensor(
+            tt_cache_path + base_address + "wk.weight" + str(self.args.WEIGHTS_DTYPE) + ".bin"
         )
         self.wk = TtLinear(
             args.dim,
             args.n_kv_heads * args.head_dim,
             self.wk_weights,
             device=self.device,
             output_mem_config=self.args.out_mem_config,
         )
 
-        wv_weights = self.state_dict[f"{base_address}wv.weight"]
-        ref_wv_weights = wv_weights.unsqueeze(0).unsqueeze(0)
-        self.wv_weights = tt_lib.tensor.Tensor(
-            ref_wv_weights.reshape(-1).tolist(),
-            ref_wv_weights.shape,
-            self.args.WEIGHTS_DTYPE,
-            tt_lib.tensor.Layout.ROW_MAJOR,
+        self.wv_weights = tt_lib.tensor.load_tensor(
+            tt_cache_path + base_address + "wv.weight" + str(self.args.WEIGHTS_DTYPE) + ".bin"
        )
         self.wv = TtLinear(
             args.dim,
             args.n_kv_heads * args.head_dim,
             self.wv_weights,
             device=self.device,
             output_mem_config=self.args.out_mem_config,
         )
 
-        wo_weights = state_dict[f"{base_address}wo.weight"]
-        ref_wo_weights = wo_weights.unsqueeze(0).unsqueeze(0)
-        self.wo_weights = tt_lib.tensor.Tensor(
-            ref_wo_weights.reshape(-1).tolist(),
-            ref_wo_weights.shape,
-            self.args.WEIGHTS_DTYPE,
-            tt_lib.tensor.Layout.ROW_MAJOR,
+        self.wo_weights = tt_lib.tensor.load_tensor(
+            tt_cache_path + base_address + "wo.weight" + str(self.args.WEIGHTS_DTYPE) + ".bin"
         )
         self.wo = TtLinear(
             args.n_heads * args.head_dim,
@@ -141,11 +120,11 @@ def forward(
         _, bsz, seqlen, _ = x.shape()
         xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
 
-        xq = fallback_ops.reshape(xq, bsz, seqlen, self.n_heads, self.args.head_dim)
+        xq = tt_lib.tensor.reshape(xq, bsz, seqlen, self.n_heads, self.args.head_dim)
 
-        xk = fallback_ops.reshape(xk, bsz, seqlen, self.n_kv_heads, self.args.head_dim)
+        xk = tt_lib.tensor.reshape(xk, bsz, seqlen, self.n_kv_heads, self.args.head_dim)
 
-        xv = fallback_ops.reshape(xv, bsz, seqlen, self.n_kv_heads, self.args.head_dim)
+        xv = tt_lib.tensor.reshape(xv, bsz, seqlen, self.n_kv_heads, self.args.head_dim)
 
         xq = tt_to_torch_tensor(xq).to(torch.float32)
         xk = tt_to_torch_tensor(xk).to(torch.float32)
         xv = tt_to_torch_tensor(xv).to(torch.float32)
@@ -203,7 +182,6 @@ def forward(
             mask = mask.squeeze()
             scores += mask[None, None, ...]
 
-        query = tt_to_torch_tensor(query)
         scores = torch_to_tt_tensor_rm(scores, self.device, put_on_device=False)
 
         if self.args.FALLBACK_SOFTMAX:
diff --git a/models/experimental/mistral/tt/mistral_feed_forward.py b/models/experimental/mistral/tt/mistral_feed_forward.py
index 2a39897a424..8ba8ce92b09 100644
--- a/models/experimental/mistral/tt/mistral_feed_forward.py
+++ b/models/experimental/mistral/tt/mistral_feed_forward.py
@@ -15,18 +15,14 @@ def __init__(
         args: TtModelArgs,
         base_address=None,
         device=None,
-        state_dict=None,
+        tt_cache_path=None,
     ):
         super().__init__()
         self.device = device
+        self.args = args
 
-        w1_weights = state_dict[f"{base_address}w1.weight"]
-        ref_w1_weights = w1_weights.unsqueeze(0).unsqueeze(0)
-        self.w1_weights = tt_lib.tensor.Tensor(
-            ref_w1_weights.reshape(-1).tolist(),
-            ref_w1_weights.shape,
-            args.WEIGHTS_DTYPE,
-            tt_lib.tensor.Layout.ROW_MAJOR,
+        self.w1_weights = tt_lib.tensor.load_tensor(
+            tt_cache_path + base_address + "w1.weight" + str(self.args.WEIGHTS_DTYPE) + ".bin"
         )
         self.w1 = TtLinear(
             args.dim,
             args.hidden_dim,
             self.w1_weights,
             device=self.device,
         )
 
-        w2_weights = state_dict[f"{base_address}w2.weight"]
-        ref_w2_weights = w2_weights.unsqueeze(0).unsqueeze(0)
-        self.w2_weights = tt_lib.tensor.Tensor(
-            ref_w2_weights.reshape(-1).tolist(),
-            ref_w2_weights.shape,
-            args.WEIGHTS_DTYPE,
-            tt_lib.tensor.Layout.ROW_MAJOR,
+        self.w2_weights = tt_lib.tensor.load_tensor(
+            tt_cache_path + base_address + "w2.weight" + str(self.args.WEIGHTS_DTYPE) + ".bin"
         )
         self.w2 = TtLinear(
             args.hidden_dim,
             args.dim,
             self.w2_weights,
             device=self.device,
         )
 
-        w3_weights = state_dict[f"{base_address}w3.weight"]
-        ref_w3_weights = w3_weights.unsqueeze(0).unsqueeze(0)
-        self.w3_weights = tt_lib.tensor.Tensor(
-            ref_w3_weights.reshape(-1).tolist(),
-            ref_w3_weights.shape,
-            args.WEIGHTS_DTYPE,
-            tt_lib.tensor.Layout.ROW_MAJOR,
+        self.w3_weights = tt_lib.tensor.load_tensor(
+            tt_cache_path + base_address + "w3.weight" + str(self.args.WEIGHTS_DTYPE) + ".bin"
         )
         self.w3 = TtLinear(
             args.dim,
diff --git a/models/experimental/mistral/tt/mistral_rms_norm.py b/models/experimental/mistral/tt/mistral_rms_norm.py
index 3b8f86b3acb..c323e434e13 100644
--- a/models/experimental/mistral/tt/mistral_rms_norm.py
+++ b/models/experimental/mistral/tt/mistral_rms_norm.py
@@ -11,23 +11,14 @@ def __init__(
         self,
         dim: int,
         eps: float = 1e-6,
-        state_dict=None,
-        device=None,
         base_address=None,
+        tt_cache_path=None,
     ):
         super().__init__()
         self.eps = eps
-        self.device = device
-        weight = state_dict[f"{base_address}weight"]
-        # converting to bfp8 reduces PCC
-        ref_weight = weight.unsqueeze(0).unsqueeze(0).unsqueeze(0)
-        self.weight = tt_lib.tensor.Tensor(
-            ref_weight.reshape(-1).tolist(),
-            ref_weight.shape,
-            tt_lib.tensor.DataType.BFLOAT16,
-            tt_lib.tensor.Layout.ROW_MAJOR,
-        )
+        # converting to bfp8 reduces PCC, so always load the bfloat16 copy of this weight
+        self.weight = tt_lib.tensor.load_tensor(tt_cache_path + base_address + "weightDataType.BFLOAT16.bin")
 
     def forward(self, x: tt_lib.tensor.Tensor) -> tt_lib.tensor.Tensor:
         return tt_lib.tensor.rmsnorm(x, self.eps, self.weight)
diff --git a/models/experimental/mistral/tt/mistral_transformer.py b/models/experimental/mistral/tt/mistral_transformer.py
index 81ff89d176b..783442a2d58 100644
--- a/models/experimental/mistral/tt/mistral_transformer.py
+++ b/models/experimental/mistral/tt/mistral_transformer.py
@@ -27,6 +27,7 @@ def __init__(
         device=None,
         state_dict=None,
         base_address=None,
+        tt_cache_path=None,
     ):
         super().__init__()
         self.args = args
@@ -43,19 +44,23 @@ def __init__(
         self.layers = torch.nn.ModuleList(
             [
                 TtTransformerBlock(
-                    args=args, state_dict=self.state_dict, base_address=f"layers.{i}.", device=self.device
+                    args=args,
+                    base_address=f"layers.{i}.",
+                    device=self.device,
+                    tt_cache_path=tt_cache_path,
                 )
                 for i in range(args.n_layers)
             ]
         )
-        self.norm = TtRMSNorm(args.dim, base_address=f"norm.", state_dict=state_dict, device=device, eps=args.norm_eps)
-        output_weight = state_dict["output.weight"]
-        ref_output_weight = output_weight.unsqueeze(0).unsqueeze(0)
-        self.output_weight = tt_lib.tensor.Tensor(
-            ref_output_weight.reshape(-1).tolist(),
-            ref_output_weight.shape,
-            args.WEIGHTS_DTYPE,
-            tt_lib.tensor.Layout.ROW_MAJOR,
+        self.norm = TtRMSNorm(
+            args.dim,
+            base_address=f"norm.",
+            eps=args.norm_eps,
+            tt_cache_path=tt_cache_path,
+        )
+
+        self.output_weight = tt_lib.tensor.load_tensor(
+            tt_cache_path + "output.weight" + str(self.args.WEIGHTS_DTYPE) + ".bin"
         )
         self.output = TtLinear(
             args.dim,
diff --git a/models/experimental/mistral/tt/mistral_transformer_block.py b/models/experimental/mistral/tt/mistral_transformer_block.py
index 788ba1d7afb..2a3cce715ec 100644
--- a/models/experimental/mistral/tt/mistral_transformer_block.py
+++ b/models/experimental/mistral/tt/mistral_transformer_block.py
@@ -15,25 +15,27 @@ class TtTransformerBlock(nn.Module):
     def __init__(
         self,
         args: TtModelArgs,
-        state_dict=None,
         device=None,
         base_address=None,
+        tt_cache_path=None,
    ):
         super().__init__()
         self.n_heads = args.n_heads
         self.dim = args.dim
         self.device = device
-        self.attention = TtAttention(args, f"{base_address}attention.", device, state_dict)
-        self.feed_forward = TtFeedForward(args, f"{base_address}feed_forward.", device, state_dict)
+        self.attention = TtAttention(args, f"{base_address}attention.", device, tt_cache_path=tt_cache_path)
+        self.feed_forward = TtFeedForward(args, f"{base_address}feed_forward.", device, tt_cache_path=tt_cache_path)
         self.attention_norm = TtRMSNorm(
             args.dim,
             base_address=f"{base_address}attention_norm.",
-            state_dict=state_dict,
-            device=device,
             eps=args.norm_eps,
+            tt_cache_path=tt_cache_path,
         )
         self.ffn_norm = TtRMSNorm(
-            args.dim, base_address=f"{base_address}ffn_norm.", state_dict=state_dict, device=device, eps=args.norm_eps
+            args.dim,
+            base_address=f"{base_address}ffn_norm.",
+            eps=args.norm_eps,
+            tt_cache_path=tt_cache_path,
         )
         self.args = args
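End to end, the new flow is: run `cache_weights_in_weka` once to dump every state-dict tensor as a tilized `.bin` file, then construct the TT modules with `tt_cache_path` so each one calls `tt_lib.tensor.load_tensor` instead of converting torch weights at init. A rough usage sketch under the same paths the tests use — `mistral_path` and `device` stand in for the pytest fixtures, and `base_address=""` is assumed from the whole-model tests; treat this as illustrative, not a drop-in script:

```python
import json
from pathlib import Path

import torch

from models.experimental.mistral.tt.mistral_configuration import TtModelArgs
from models.experimental.mistral.tt.mistral_transformer import TtTransformer

mistral_path = Path("mistral-7B-v0.1")  # wherever model_location_generator resolved the checkpoint
tt_cache_path = "/mnt/MLPerf/tt_dnn-models/tt/Mistral/"  # Weka mount holding the dumped .bin files

state_dict = torch.load(mistral_path / "consolidated.00.pth")
with open(mistral_path / "params.json", "r") as f:
    model_args = TtModelArgs(**json.loads(f.read()))
model_args.max_batch_size = 1
model_args.n_layers = 32

device = ...  # an already-open tt_lib device, e.g. the pytest `device` fixture

# state_dict is still passed in: TtTransformer keeps it for the parts this patch
# does not cache (presumably the token embedding), while every per-layer weight
# now comes from tt_cache_path via tt_lib.tensor.load_tensor.
tt_model = TtTransformer(
    args=model_args,
    state_dict=state_dict,
    device=device,
    base_address="",
    tt_cache_path=tt_cache_path,
)
```

Note that the RMSNorm layers always read a `DataType.BFLOAT16` file while the linear layers read `args.WEIGHTS_DTYPE`, so when `WEIGHTS_DTYPE` is bfp8 the cache directory needs both variants present.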