#3824: cache weight tensors for mistral
vigneshkeerthivasanx committed Nov 24, 2023
1 parent 21b2726 commit 8106601
Showing 13 changed files with 104 additions and 92 deletions.
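
The commit replaces per-module torch-to-TT weight conversion (unsqueeze to 4D, flatten, build a ROW_MAJOR tensor, tilize) with a one-time dump of pre-tilized tensors to disk; each module then reads its weights back with tt_lib.tensor.load_tensor at construction time. A minimal sketch of the round-trip, using only calls that appear in the diff below (the key, path, and sizes are illustrative, not taken from the commit):

import torch
import tt_lib

cache_path = "/mnt/MLPerf/tt_dnn-models/tt/Mistral/"   # weka path used by the tests below
key = "layers.0.attention.wq.weight"                   # illustrative state_dict key
dtype = tt_lib.tensor.DataType.BFLOAT16
value = torch.randn(32, 32).unsqueeze(0).unsqueeze(0)  # 4D and tile-aligned, kept small here

# One-time conversion: build a ROW_MAJOR tensor, tilize it, and dump it to disk.
tt_tensor = tt_lib.tensor.Tensor(
    value.reshape(-1).tolist(),
    value.shape,
    dtype,
    tt_lib.tensor.Layout.ROW_MAJOR,
).to(tt_lib.tensor.Layout.TILE)
tt_lib.tensor.dump_tensor(cache_path + key + str(dtype) + ".bin", tt_tensor)

# Every subsequent run: load the pre-converted tensor instead of rebuilding it.
weight = tt_lib.tensor.load_tensor(cache_path + key + str(dtype) + ".bin")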
2 changes: 2 additions & 0 deletions models/experimental/mistral/demo/gs_demo.py
@@ -35,11 +35,13 @@ def test_gs_demo_single_input_inference(batch_size, model_location_generator, de
model_args.max_batch_size = batch_size
model_args.n_layers = 32

tt_cache_path = "/mnt/MLPerf/tt_dnn-models/tt/Mistral/"
tt_model = TtTransformer(
args=model_args,
state_dict=state_dict,
device=device,
base_address=base_address,
tt_cache_path=tt_cache_path,
)

tt_output = generate(
36 changes: 36 additions & 0 deletions models/experimental/mistral/mistral_utils.py
@@ -9,6 +9,10 @@
from sentencepiece import SentencePieceProcessor
from pathlib import Path
from models.utility_functions import tt_to_torch_tensor
import tt_lib
import json
from models.experimental.mistral.tt.mistral_configuration import TtModelArgs
from tt_lib.utils import pad_weight


class Tokenizer:
@@ -76,3 +80,35 @@ def generate(prompts: List[str], model: TtTransformer, tokenizer: Tokenizer, max
for i, x in enumerate(encoded_prompts):
res.append(tokenizer.decode(x[:min_prompt_len] + generated[i].tolist()))
return res


def cache_weights_in_weka(model_location_generator, device, dtype, reset_seeds):
mistral_path = model_location_generator("mistral-7B-v0.1", model_subdir="Mistral")
state_dict = torch.load(mistral_path / "consolidated.00.pth")
with open(mistral_path / "params.json", "r") as f:
model_args = TtModelArgs(**json.loads(f.read()))
weights_dtype = dtype

    # initial weights are dumped to "models/experimental/mistral/weights/" and then moved to the weka path
file_name = "models/experimental/mistral/weights/"
for key, value in state_dict.items():
if len(value.shape) == 1:
value = value.unsqueeze(0).unsqueeze(0).unsqueeze(0)
else:
value = value.unsqueeze(0).unsqueeze(0)
if value.shape[-2] % 32 == 0 and value.shape[-1] % 32 == 0:
value = tt_lib.tensor.Tensor(
value.reshape(-1).tolist(),
value.shape,
weights_dtype,
tt_lib.tensor.Layout.ROW_MAJOR,
).to(tt_lib.tensor.Layout.TILE)
else:
value = pad_weight(value)
value = tt_lib.tensor.Tensor(
value.reshape(-1).tolist(),
value.shape,
weights_dtype,
tt_lib.tensor.Layout.ROW_MAJOR,
).to(tt_lib.tensor.Layout.TILE)
tt_lib.tensor.dump_tensor(file_name + str(key) + str(weights_dtype) + ".bin", value)
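
The dumped file names concatenate the raw state_dict key with the stringified dtype, and weights whose last two dimensions are not multiples of 32 are first padded with pad_weight, since TILE layout requires 32x32-aligned faces. A hedged example of reconstructing one cache path (values illustrative; str(weights_dtype) renders as e.g. "DataType.BFLOAT16", matching the literal file name used in mistral_rms_norm.py below):

import tt_lib

weights_dtype = tt_lib.tensor.DataType.BFLOAT16
key = "layers.0.attention.wq.weight"
# Resulting file: models/experimental/mistral/weights/layers.0.attention.wq.weightDataType.BFLOAT16.bin
cache_file = "models/experimental/mistral/weights/" + key + str(weights_dtype) + ".bin"
wq_weights = tt_lib.tensor.load_tensor(cache_file)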
5 changes: 3 additions & 2 deletions models/experimental/mistral/tests/test_mistral_attention.py
@@ -54,7 +54,7 @@ def test_mistral_attention_inference(
):
mistral_path = model_location_generator("mistral-7B-v0.1", model_subdir="Mistral")
state_dict = torch.load(mistral_path / "consolidated.00.pth")
base_address = f""
base_address = f"layers.0.attention."
with open(mistral_path / "params.json", "r") as f:
model_args = TtModelArgs(**json.loads(f.read()))
if True:
@@ -68,11 +68,12 @@
model_args.FALLBACK_EMPTY = empty_ondevice
model_args.FALLBACK_SCATTER = scatter_ondevice
model_args.WEIGHTS_DTYPE = dtype
tt_cache_path = "/mnt/MLPerf/tt_dnn-models/tt/Mistral/"
tt_model = TtAttention(
args=model_args,
state_dict=state_dict,
device=device,
base_address=base_address,
tt_cache_path=tt_cache_path,
)
input = torch.randn(1, 11, 4096)
empty_tensor = torch.zeros((11, 64))
models/experimental/mistral/tests/test_mistral_feed_forward.py
@@ -27,7 +27,7 @@
def test_mistral_feed_forward_inference(pcc, model_location_generator, device, dtype, reset_seeds):
mistral_path = model_location_generator("mistral-7B-v0.1", model_subdir="Mistral")
state_dict = torch.load(mistral_path / "consolidated.00.pth")
base_address = f""
base_address = f"layers.0.feed_forward."
with open(mistral_path / "params.json", "r") as f:
model_args = TtModelArgs(**json.loads(f.read()))

@@ -37,11 +37,13 @@ def test_mistral_feed_forward_inference(pcc, model_location_generator, device, d
reference_model = FeedForward(args=model_args)
reference_model.load_state_dict(state_dict)

tt_cache_path = "/mnt/MLPerf/tt_dnn-models/tt/Mistral/"

tt_model = TtFeedForward(
args=model_args,
state_dict=state_dict,
device=device,
base_address=base_address,
tt_cache_path=tt_cache_path,
)
input = torch.rand(1, 11, 4096)
reference_output = reference_model(input)
6 changes: 3 additions & 3 deletions models/experimental/mistral/tests/test_mistral_rms_norm.py
@@ -23,7 +23,7 @@
def test_mistral_rms_norm_inference(pcc, model_location_generator, device, reset_seeds):
mistral_path = model_location_generator("mistral-7B-v0.1", model_subdir="Mistral")
state_dict = torch.load(mistral_path / "consolidated.00.pth")
base_address = f""
base_address = f"layers.0.attention_norm."
with open(mistral_path / "params.json", "r") as f:
model_args = TtModelArgs(**json.loads(f.read()))

@@ -34,11 +34,11 @@ def test_mistral_rms_norm_inference(pcc, model_location_generator, device, reset
reference_model = RMSNorm(dim=dim)
reference_model.load_state_dict(state_dict)

tt_cache_path = "/mnt/MLPerf/tt_dnn-models/tt/Mistral/"
tt_model = TtRMSNorm(
dim=dim,
state_dict=state_dict,
device=device,
base_address=base_address,
tt_cache_path=tt_cache_path,
)
input = torch.rand(1, 11, 4096)
reference_output = reference_model(input)
5 changes: 5 additions & 0 deletions models/experimental/mistral/tests/test_mistral_transformer.py
@@ -1,6 +1,7 @@
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0
import tt_lib
import torch
import tt_lib
import pytest
@@ -38,17 +39,21 @@ def test_mistral_transformer_inference(pcc, model_location_generator, device, dt
with open(mistral_path / "params.json", "r") as f:
model_args = TtModelArgs(**json.loads(f.read()))

model_args.WEIGHTS_DTYPE = dtype
model_args.max_batch_size = 1
model_args.n_layers = 32
model_args.WEIGHTS_DTYPE = dtype
reference_model = Transformer(args=model_args)
reference_model.load_state_dict(state_dict)

tt_cache_path = "/mnt/MLPerf/tt_dnn-models/tt/Mistral/"

tt_model = TtTransformer(
args=model_args,
state_dict=state_dict,
device=device,
base_address=base_address,
tt_cache_path=tt_cache_path,
)

encoded_prompts = [tokenizer.encode(prompt) for prompt in prompts]
models/experimental/mistral/tests/test_mistral_transformer_block.py
@@ -28,7 +28,7 @@
def test_mistral_transformer_block_inference(pcc, model_location_generator, device, dtype, reset_seeds):
mistral_path = model_location_generator("mistral-7B-v0.1", model_subdir="Mistral")
state_dict = torch.load(mistral_path / "consolidated.00.pth")
base_address = f""
base_address = f"layers.0."
with open(mistral_path / "params.json", "r") as f:
model_args = TtModelArgs(**json.loads(f.read()))

@@ -39,11 +39,13 @@ def test_mistral_transformer_block_inference(pcc, model_location_generator, devi
reference_model = TransformerBlock(args=model_args)
reference_model.load_state_dict(state_dict)

tt_cache_path = "/mnt/MLPerf/tt_dnn-models/tt/Mistral/"

tt_model = TtTransformerBlock(
args=model_args,
state_dict=state_dict,
device=device,
base_address=base_address,
tt_cache_path=tt_cache_path,
)

input = torch.randn(1, 11, 4096)
2 changes: 2 additions & 0 deletions models/experimental/mistral/tests/test_perf_mistral.py
@@ -55,11 +55,13 @@ def run_perf_mistral(expected_inference_time, expected_compile_time, device, mod
Path(mistral_path), n_layers=32, max_batch_size=max_batch_size, is_whole_model=True
)

tt_cache_path = "/mnt/MLPerf/tt_dnn-models/tt/Mistral/"
tt_model = TtTransformer(
args=model_args,
state_dict=state_dict,
device=device,
base_address=base_address,
tt_cache_path=tt_cache_path,
)

with torch.no_grad():
46 changes: 12 additions & 34 deletions models/experimental/mistral/tt/mistral_attention.py
@@ -19,13 +19,12 @@ def __init__(
args: TtModelArgs,
base_address=None,
device=None,
state_dict=None,
tt_cache_path=None,
):
super().__init__()
self.args = args
self.device = device
self.base_address = base_address
self.state_dict = state_dict

self.n_heads: int = args.n_heads
self.n_kv_heads: int = args.n_kv_heads
@@ -35,13 +34,8 @@ def __init__(

self.scale = self.args.head_dim**-0.5

wq_weights = self.state_dict[f"{base_address}wq.weight"]
ref_wq_weights = wq_weights.unsqueeze(0).unsqueeze(0)
self.wq_weights = tt_lib.tensor.Tensor(
ref_wq_weights.reshape(-1).tolist(),
ref_wq_weights.shape,
self.args.WEIGHTS_DTYPE,
tt_lib.tensor.Layout.ROW_MAJOR,
self.wq_weights = tt_lib.tensor.load_tensor(
tt_cache_path + base_address + "wq.weight" + str(self.args.WEIGHTS_DTYPE) + ".bin"
)
self.wq = TtLinear(
args.dim,
@@ -51,13 +45,8 @@
output_mem_config=self.args.out_mem_config,
)

wk_weights = self.state_dict[f"{base_address}wk.weight"]
ref_wk_weights = wk_weights.unsqueeze(0).unsqueeze(0)
self.wk_weights = tt_lib.tensor.Tensor(
ref_wk_weights.reshape(-1).tolist(),
ref_wk_weights.shape,
self.args.WEIGHTS_DTYPE,
tt_lib.tensor.Layout.ROW_MAJOR,
self.wk_weights = tt_lib.tensor.load_tensor(
tt_cache_path + base_address + "wk.weight" + str(self.args.WEIGHTS_DTYPE) + ".bin"
)
self.wk = TtLinear(
args.dim,
@@ -67,13 +56,8 @@
output_mem_config=self.args.out_mem_config,
)

wv_weights = self.state_dict[f"{base_address}wv.weight"]
ref_wv_weights = wv_weights.unsqueeze(0).unsqueeze(0)
self.wv_weights = tt_lib.tensor.Tensor(
ref_wv_weights.reshape(-1).tolist(),
ref_wv_weights.shape,
self.args.WEIGHTS_DTYPE,
tt_lib.tensor.Layout.ROW_MAJOR,
self.wv_weights = tt_lib.tensor.load_tensor(
tt_cache_path + base_address + "wv.weight" + str(self.args.WEIGHTS_DTYPE) + ".bin"
)
self.wv = TtLinear(
args.dim,
@@ -83,13 +67,8 @@
output_mem_config=self.args.out_mem_config,
)

wo_weights = state_dict[f"{base_address}wo.weight"]
ref_wo_weights = wo_weights.unsqueeze(0).unsqueeze(0)
self.wo_weights = tt_lib.tensor.Tensor(
ref_wo_weights.reshape(-1).tolist(),
ref_wo_weights.shape,
self.args.WEIGHTS_DTYPE,
tt_lib.tensor.Layout.ROW_MAJOR,
self.wo_weights = tt_lib.tensor.load_tensor(
tt_cache_path + base_address + "wo.weight" + str(self.args.WEIGHTS_DTYPE) + ".bin"
)
self.wo = TtLinear(
args.n_heads * args.head_dim,
@@ -141,11 +120,11 @@ def forward(
_, bsz, seqlen, _ = x.shape()
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)

xq = fallback_ops.reshape(xq, bsz, seqlen, self.n_heads, self.args.head_dim)
xq = tt_lib.tensor.reshape(xq, bsz, seqlen, self.n_heads, self.args.head_dim)

xk = fallback_ops.reshape(xk, bsz, seqlen, self.n_kv_heads, self.args.head_dim)
xk = tt_lib.tensor.reshape(xk, bsz, seqlen, self.n_kv_heads, self.args.head_dim)

xv = fallback_ops.reshape(xv, bsz, seqlen, self.n_kv_heads, self.args.head_dim)
xv = tt_lib.tensor.reshape(xv, bsz, seqlen, self.n_kv_heads, self.args.head_dim)

xq = tt_to_torch_tensor(xq).to(torch.float32)
xk = tt_to_torch_tensor(xk).to(torch.float32)
@@ -203,7 +182,6 @@ def forward(
mask = mask.squeeze()
scores += mask[None, None, ...]

query = tt_to_torch_tensor(query)
scores = torch_to_tt_tensor_rm(scores, self.device, put_on_device=False)

if self.args.FALLBACK_SOFTMAX:
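All four projection loads in TtAttention (and the w1/w2/w3 loads in the feed-forward module below) follow the same path convention, so they could be factored through a small helper; the sketch below is ours, not part of the commit:

import tt_lib

def load_cached_weight(tt_cache_path, base_address, name, weights_dtype):
    # Mirrors the load calls above: cache path + module prefix + tensor name + dtype + ".bin"
    return tt_lib.tensor.load_tensor(tt_cache_path + base_address + name + str(weights_dtype) + ".bin")

# e.g.: self.wq_weights = load_cached_weight(tt_cache_path, base_address, "wq.weight", self.args.WEIGHTS_DTYPE)
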
30 changes: 8 additions & 22 deletions models/experimental/mistral/tt/mistral_feed_forward.py
@@ -15,18 +15,14 @@ def __init__(
args: TtModelArgs,
base_address=None,
device=None,
state_dict=None,
tt_cache_path=None,
):
super().__init__()
self.device = device
self.args = args

w1_weights = state_dict[f"{base_address}w1.weight"]
ref_w1_weights = w1_weights.unsqueeze(0).unsqueeze(0)
self.w1_weights = tt_lib.tensor.Tensor(
ref_w1_weights.reshape(-1).tolist(),
ref_w1_weights.shape,
args.WEIGHTS_DTYPE,
tt_lib.tensor.Layout.ROW_MAJOR,
self.w1_weights = tt_lib.tensor.load_tensor(
tt_cache_path + base_address + "w1.weight" + str(self.args.WEIGHTS_DTYPE) + ".bin"
)
self.w1 = TtLinear(
args.dim,
@@ -35,13 +31,8 @@
device=self.device,
)

w2_weights = state_dict[f"{base_address}w2.weight"]
ref_w2_weights = w2_weights.unsqueeze(0).unsqueeze(0)
self.w2_weights = tt_lib.tensor.Tensor(
ref_w2_weights.reshape(-1).tolist(),
ref_w2_weights.shape,
args.WEIGHTS_DTYPE,
tt_lib.tensor.Layout.ROW_MAJOR,
self.w2_weights = tt_lib.tensor.load_tensor(
tt_cache_path + base_address + "w2.weight" + str(self.args.WEIGHTS_DTYPE) + ".bin"
)
self.w2 = TtLinear(
args.hidden_dim,
@@ -50,13 +41,8 @@
device=self.device,
)

w3_weights = state_dict[f"{base_address}w3.weight"]
ref_w3_weights = w3_weights.unsqueeze(0).unsqueeze(0)
self.w3_weights = tt_lib.tensor.Tensor(
ref_w3_weights.reshape(-1).tolist(),
ref_w3_weights.shape,
args.WEIGHTS_DTYPE,
tt_lib.tensor.Layout.ROW_MAJOR,
self.w3_weights = tt_lib.tensor.load_tensor(
tt_cache_path + base_address + "w3.weight" + str(self.args.WEIGHTS_DTYPE) + ".bin"
)
self.w3 = TtLinear(
args.dim,
15 changes: 3 additions & 12 deletions models/experimental/mistral/tt/mistral_rms_norm.py
@@ -11,23 +11,14 @@ def __init__(
self,
dim: int,
eps: float = 1e-6,
state_dict=None,
device=None,
base_address=None,
tt_cache_path=None,
):
super().__init__()
self.eps = eps
self.device = device
weight = state_dict[f"{base_address}weight"]

# converting to bfp8 reduces PCC
ref_weight = weight.unsqueeze(0).unsqueeze(0).unsqueeze(0)
self.weight = tt_lib.tensor.Tensor(
ref_weight.reshape(-1).tolist(),
ref_weight.shape,
tt_lib.tensor.DataType.BFLOAT16,
tt_lib.tensor.Layout.ROW_MAJOR,
)
# bfp8 reduces PCC, so the weight is kept in bfloat16
self.weight = tt_lib.tensor.load_tensor(tt_cache_path + base_address + "weightDataType.BFLOAT16.bin")

def forward(self, x: tt_lib.tensor.Tensor) -> tt_lib.tensor.Tensor:
return tt_lib.tensor.rmsnorm(x, self.eps, self.weight)