diff --git a/models/experimental/nanogpt/nanogpt_utils.py b/models/experimental/nanogpt/nanogpt_utils.py new file mode 100644 index 00000000000..aa221481837 --- /dev/null +++ b/models/experimental/nanogpt/nanogpt_utils.py @@ -0,0 +1,109 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +from models.utility_functions import tt2torch_tensor +import torch +import tt_lib +from transformers import GPT2LMHeadModel +from tt_lib.utils import pad_weight +from pathlib import Path +import os + + +def unpad_from_zero(x, desired_shape): + if x.shape()[-1] == desired_shape[-1] and x.shape()[-2] == desired_shape[-2]: + x = tt2torch_tensor(x) + else: + x = x.cpu() + if x.layout() != tt_lib.tensor.Layout.ROW_MAJOR: + x = x.to(tt_lib.tensor.Layout.ROW_MAJOR) + x = x.unpad( + (0, 0, 0, 0), (desired_shape[0] - 1, desired_shape[1] - 1, desired_shape[2] - 1, desired_shape[3] - 1) + ) + x = x.to_torch().to(torch.float) + return x + + +def cache_weights_in_weka(device, dtype, reset_seeds): + model_hf = GPT2LMHeadModel.from_pretrained("gpt2") + state_dict = model_hf.state_dict() + weights_dtype = dtype + + # initial weights are stored in "models/experimental/nanogpt/weights/" and moved to weka path + file_name = "models/experimental/nanogpt/weights/" + for key, value in state_dict.items(): + if key.startswith("transformer.wte.") or key.startswith("transformer.wpe."): + torch.save(value, file_name + str(key) + ".pt") + continue + elif len(value.shape) == 0: + continue + while len(value.shape) < 4: + value = value.unsqueeze(0) + if value.shape[-2] % 32 == 0 and value.shape[-1] % 32 == 0: + value = tt_lib.tensor.Tensor( + value.reshape(-1).tolist(), + value.shape, + weights_dtype, + tt_lib.tensor.Layout.ROW_MAJOR, + ).to(tt_lib.tensor.Layout.TILE) + else: + value = pad_weight(value) + value = tt_lib.tensor.Tensor( + value.reshape(-1).tolist(), + value.shape, + weights_dtype, + tt_lib.tensor.Layout.ROW_MAJOR, + ).to(tt_lib.tensor.Layout.TILE) + tt_lib.tensor.dump_tensor(file_name + str(key) + str(weights_dtype) + ".bin", value) + + +"""This function will load weights from the state_dict and check if the needed weights are available in given path. +If they are not available, it will convert torch tensor weights to TT tensor weights and store them in the given path.""" + + +def store_weights(model_version, file_name, base_address, dtype): + model_hf = GPT2LMHeadModel.from_pretrained(model_version) + state_dict = model_hf.state_dict() + weights_dtype = dtype + + for key, value in state_dict.items(): + if base_address == "" and ( + (key.startswith("transformer.wte.") and os.path.exists(file_name + str(key) + ".pt") == False) + or (key.startswith("transformer.wpe.") and os.path.exists(file_name + str(key) + ".pt") == False) + ): + torch.save(value, file_name + str(key) + ".pt") + continue + if key.startswith("transformer.wte.") or key.startswith("transformer.wpe.") or (len(value.shape) == 0): + continue + if (os.path.exists(file_name + str(key) + str(weights_dtype) + ".bin")) or ( + key.startswith(base_address) == False and base_address != "" + ): + continue + while len(value.shape) < 4: + value = value.unsqueeze(0) + if value.shape[-2] % 32 == 0 and value.shape[-1] % 32 == 0: + value = tt_lib.tensor.Tensor( + value.reshape(-1).tolist(), + value.shape, + weights_dtype, + tt_lib.tensor.Layout.ROW_MAJOR, + ).to(tt_lib.tensor.Layout.TILE) + else: + value = pad_weight(value) + value = tt_lib.tensor.Tensor( + value.reshape(-1).tolist(), + value.shape, + weights_dtype, + tt_lib.tensor.Layout.ROW_MAJOR, + ).to(tt_lib.tensor.Layout.TILE) + tt_lib.tensor.dump_tensor(file_name + str(key) + str(weights_dtype) + ".bin", value) + + +def get_tt_cache_path(model_version): + tt_cache_path = Path("/mnt/MLPerf/tt_dnn-models/tt/NanoGPT") / model_version + if tt_cache_path.exists(): + return str(tt_cache_path) + "/" + else: + Path(f"models/experimental/nanogpt/datasets/{model_version}").mkdir(parents=True, exist_ok=True) + return str(Path(f"models/experimental/nanogpt/datasets/{model_version}")) + "/" diff --git a/models/experimental/nanogpt/tests/test_nanogpt_attention.py b/models/experimental/nanogpt/tests/test_nanogpt_attention.py index 67c3569d22c..11c2a8f1abb 100644 --- a/models/experimental/nanogpt/tests/test_nanogpt_attention.py +++ b/models/experimental/nanogpt/tests/test_nanogpt_attention.py @@ -4,12 +4,15 @@ import torch import pytest +import tt_lib +import os +from pathlib import Path from transformers import GPT2LMHeadModel - from loguru import logger import models.experimental.nanogpt.tt.nanogpt_attention as nanogpt_attention +from models.experimental.nanogpt.nanogpt_utils import get_tt_cache_path, store_weights from models.utility_functions import ( tt_to_torch_tensor, @@ -19,16 +22,17 @@ ) +@pytest.mark.parametrize( + "dtype", + (tt_lib.tensor.DataType.BFLOAT16,), +) @pytest.mark.parametrize( "pcc", ((0.99,),), ) - -def test_nanogpt_attn(device, pcc, reset_seeds): - +def test_nanogpt_attn(device, pcc, dtype, reset_seeds): # Prepare input model_hf = GPT2LMHeadModel.from_pretrained("gpt2") - sd = model_hf.state_dict() config = model_hf.config model_hf.eval() block = 0 @@ -38,8 +42,17 @@ def test_nanogpt_attn(device, pcc, reset_seeds): pt_attn = model_hf.transformer.h[block].attn pt_out = pt_attn.forward(test_in) + model_version = "gpt2" + tt_cache_path = get_tt_cache_path(model_version) + + if ( + tt_cache_path == (str(Path(f"models/experimental/nanogpt/datasets/{model_version}")) + "/") + and len(os.listdir(f"models/experimental/nanogpt/datasets/{model_version}")) < 320 + ): + store_weights(model_version=model_version, file_name=tt_cache_path, dtype=dtype, base_address=base_address) + tt_test_in = torch_to_tt_tensor_rm(test_in, device) - tt_attn = nanogpt_attention.TtCausalSelfAttention(config, sd, base_address, device) + tt_attn = nanogpt_attention.TtCausalSelfAttention(config, base_address, device, tt_cache_path, dtype) tt_out = tt_attn.forward(tt_test_in) diff --git a/models/experimental/nanogpt/tests/test_nanogpt_block.py b/models/experimental/nanogpt/tests/test_nanogpt_block.py index e22b0ab0187..036a5709ed1 100644 --- a/models/experimental/nanogpt/tests/test_nanogpt_block.py +++ b/models/experimental/nanogpt/tests/test_nanogpt_block.py @@ -4,11 +4,15 @@ import torch import pytest +import tt_lib +from pathlib import Path +import os from transformers import GPT2LMHeadModel from loguru import logger import models.experimental.nanogpt.tt.nanogpt_block as nanogpt_block +from models.experimental.nanogpt.nanogpt_utils import get_tt_cache_path, store_weights from models.utility_functions import ( tt_to_torch_tensor, @@ -18,14 +22,16 @@ ) +@pytest.mark.parametrize( + "dtype", + (tt_lib.tensor.DataType.BFLOAT16,), +) @pytest.mark.parametrize( "pcc", ((0.99,),), ) -def test_nanogpt_block(device, pcc, reset_seeds): - +def test_nanogpt_block(device, pcc, dtype, reset_seeds): model_hf = GPT2LMHeadModel.from_pretrained("gpt2") - sd = model_hf.state_dict() config = model_hf.config model_hf.eval() block = 0 @@ -36,8 +42,16 @@ def test_nanogpt_block(device, pcc, reset_seeds): pt_out = pt_block.forward(test_in) tt_test_in = torch_to_tt_tensor_rm(test_in, device) + model_version = "gpt2" + tt_cache_path = get_tt_cache_path(model_version) + + if ( + tt_cache_path == (str(Path(f"models/experimental/nanogpt/datasets/{model_version}")) + "/") + and len(os.listdir(f"models/experimental/nanogpt/datasets/{model_version}")) < 320 + ): + store_weights(model_version=model_version, file_name=tt_cache_path, dtype=dtype, base_address=base_address) - tt_block = nanogpt_block.TtBlock(config, sd, base_address, device) + tt_block = nanogpt_block.TtBlock(config, base_address, device, tt_cache_path, dtype) tt_block.eval() tt_out = tt_block.forward(tt_test_in) diff --git a/models/experimental/nanogpt/tests/test_nanogpt_mlp.py b/models/experimental/nanogpt/tests/test_nanogpt_mlp.py index 466f102e575..90097d42447 100644 --- a/models/experimental/nanogpt/tests/test_nanogpt_mlp.py +++ b/models/experimental/nanogpt/tests/test_nanogpt_mlp.py @@ -4,6 +4,10 @@ import torch import pytest +import tt_lib +from models.experimental.nanogpt.nanogpt_utils import get_tt_cache_path, store_weights +from pathlib import Path +import os from transformers import GPT2LMHeadModel @@ -19,14 +23,16 @@ ) +@pytest.mark.parametrize( + "dtype", + (tt_lib.tensor.DataType.BFLOAT16,), +) @pytest.mark.parametrize( "pcc", ((0.99,),), ) -def test_nanogpt_mlp(device, pcc, reset_seeds): - +def test_nanogpt_mlp(device, pcc, dtype, reset_seeds): model_hf = GPT2LMHeadModel.from_pretrained("gpt2") - sd = model_hf.state_dict() config = model_hf.config model_hf.eval() block = 0 @@ -34,7 +40,16 @@ def test_nanogpt_mlp(device, pcc, reset_seeds): test_in = torch.rand(1, 43, 768) tt_test_in = torch_to_tt_tensor_rm(test_in, device) - tt_mlp = nanogpt_mlp.TtMLP(base_address, config, sd, device) + model_version = "gpt2" + tt_cache_path = get_tt_cache_path(model_version) + + if ( + tt_cache_path == (str(Path(f"models/experimental/nanogpt/datasets/{model_version}")) + "/") + and len(os.listdir(f"models/experimental/nanogpt/datasets/{model_version}")) < 320 + ): + store_weights(model_version=model_version, file_name=tt_cache_path, dtype=dtype, base_address=base_address) + + tt_mlp = nanogpt_mlp.TtMLP(base_address, config, device, tt_cache_path, dtype) tt_out = tt_mlp.forward(tt_test_in) diff --git a/models/experimental/nanogpt/tests/test_nanogpt_model_real.py b/models/experimental/nanogpt/tests/test_nanogpt_model.py similarity index 57% rename from models/experimental/nanogpt/tests/test_nanogpt_model_real.py rename to models/experimental/nanogpt/tests/test_nanogpt_model.py index 0217cfa58da..b5abb501c0b 100644 --- a/models/experimental/nanogpt/tests/test_nanogpt_model_real.py +++ b/models/experimental/nanogpt/tests/test_nanogpt_model.py @@ -2,10 +2,13 @@ # SPDX-License-Identifier: Apache-2.0 -import torch +import tt_lib import pytest from transformers import GPT2Tokenizer, GPT2LMHeadModel +from models.experimental.nanogpt.nanogpt_utils import get_tt_cache_path, store_weights +from pathlib import Path +import os from loguru import logger import models.experimental.nanogpt.tt.nanogpt_model as nanogpt_model @@ -13,17 +16,18 @@ from models.utility_functions import tt_to_torch_tensor, comp_allclose, comp_pcc - +@pytest.mark.parametrize( + "dtype", + (tt_lib.tensor.DataType.BFLOAT16,), +) @pytest.mark.parametrize( "pcc, prompt", - ((0.99, "Hello, my dog is a little"),), + ((0.98, "Hello, my dog is a little"),), ) -def test_nanogpt_model_real(device, pcc, prompt, reset_seeds): - +def test_nanogpt_model_real(device, pcc, prompt, dtype, reset_seeds): # Prepare input model_hf = GPT2LMHeadModel.from_pretrained("gpt2") tokenizer = GPT2Tokenizer.from_pretrained("gpt2") - sd = model_hf.state_dict() model_hf.eval() inputs = tokenizer(prompt, return_tensors="pt", padding=False) @@ -33,7 +37,17 @@ def test_nanogpt_model_real(device, pcc, prompt, reset_seeds): config = model_hf.config - tt_model = nanogpt_model.TtGPT(config, sd, device) + base_address = "" + model_version = "gpt2" + tt_cache_path = get_tt_cache_path(model_version) + + if ( + tt_cache_path == (str(Path(f"models/experimental/nanogpt/datasets/{model_version}")) + "/") + and len(os.listdir(f"models/experimental/nanogpt/datasets/{model_version}")) < 320 + ): + store_weights(model_version=model_version, file_name=tt_cache_path, dtype=dtype, base_address=base_address) + + tt_model = nanogpt_model.TtGPT(config, device, tt_cache_path, dtype) tt_out = tt_model.forward(inputs.input_ids) diff --git a/models/experimental/nanogpt/tt/nanogpt.py b/models/experimental/nanogpt/tt/nanogpt.py new file mode 100644 index 00000000000..7b2a0ca95af --- /dev/null +++ b/models/experimental/nanogpt/tt/nanogpt.py @@ -0,0 +1,24 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +from transformers import GPT2LMHeadModel +from models.experimental.nanogpt.tt.nanogpt_model import TtGPT + + +def _nanogpt(config, device, tt_cache_path, dtype): + return TtGPT( + config=config, + device=device, + tt_cache_path=tt_cache_path, + dtype=dtype, + ) + + +def nanogpt_model(device, dtype) -> TtGPT: + model_name = "gpt2" + model = GPT2LMHeadModel.from_pretrained(model_name) + config = model.config + tt_cache_path = "/mnt/MLPerf/tt_dnn-models/tt/NanoGPT/" + model = _nanogpt(config, device, tt_cache_path, dtype) + return model diff --git a/models/experimental/nanogpt/tt/nanogpt_attention.py b/models/experimental/nanogpt/tt/nanogpt_attention.py index eb37adf6624..57f148ba097 100644 --- a/models/experimental/nanogpt/tt/nanogpt_attention.py +++ b/models/experimental/nanogpt/tt/nanogpt_attention.py @@ -2,7 +2,6 @@ # SPDX-License-Identifier: Apache-2.0 -import torch import torch.nn as nn import tt_lib import math @@ -16,7 +15,7 @@ class TtCausalSelfAttention(nn.Module): - def __init__(self, config, state_dict, base_address, device): + def __init__(self, config, base_address, device, tt_cache_path, dtype): super().__init__() assert config.n_embd % config.n_head == 0 @@ -25,28 +24,34 @@ def __init__(self, config, state_dict, base_address, device): self.device = device # Get the weights - self.tt_weight_c_attn = state_dict[f"{base_address}.c_attn.weight"] - self.tt_weight_c_proj = state_dict[f"{base_address}.c_proj.weight"] - - # Push weights to Ttp device - self.tt_weight_c_attn = torch_to_tt_tensor_rm(self.tt_weight_c_attn, self.device) + self.tt_weight_c_attn = tt_lib.tensor.load_tensor( + tt_cache_path + base_address + ".c_attn.weight" + str(dtype) + ".bin" + ) - self.tt_weight_c_proj = torch_to_tt_tensor_rm(self.tt_weight_c_proj, self.device) + self.tt_weight_c_proj = tt_lib.tensor.load_tensor( + tt_cache_path + base_address + ".c_proj.weight" + str(dtype) + ".bin" + ) self.tt_weight_c_attn = tt_lib.tensor.transpose(self.tt_weight_c_attn, -2, -1) self.tt_weight_c_proj = tt_lib.tensor.transpose(self.tt_weight_c_proj, -2, -1) # Load biases - self.tt_bias_c_attn = torch_to_tt_tensor_rm(state_dict[f"{base_address}.c_attn.bias"], self.device) + self.tt_bias_c_attn = tt_lib.tensor.load_tensor( + tt_cache_path + base_address + ".c_attn.bias" + str(dtype) + ".bin" + ) - self.tt_bias_c_proj = torch_to_tt_tensor_rm(state_dict[f"{base_address}.c_proj.bias"], self.device) + self.tt_bias_c_proj = tt_lib.tensor.load_tensor( + tt_cache_path + base_address + ".c_proj.bias" + str(dtype) + ".bin" + ) self.n_head = self.config.n_head self.n_embd = self.config.n_embd + temp_bias = tt_lib.tensor.tril(tt_lib.tensor.ones([1, 1, self.block_size, self.block_size])) + temp_bias = tt_to_torch_tensor(temp_bias) self.register_buffer( "bias", - torch.tril(torch.ones(self.block_size, self.block_size)).view(1, 1, self.block_size, self.block_size), + temp_bias, ) self.c_attn = Linear( @@ -82,16 +87,16 @@ def forward(self, x: tt_lib.tensor.Tensor) -> tt_lib.tensor.Tensor: q, k, v = pt_x1.split(self.n_embd, dim=2) k = torch_to_tt_tensor_rm(k, self.device) - k = fallback_ops.reshape(k, B, T, self.n_head, C // self.n_head) + k = tt_lib.tensor.reshape(k, B, T, self.n_head, C // self.n_head) k = tt_lib.tensor.transpose(k, 1, 2) q = torch_to_tt_tensor_rm(q, self.device) - q = fallback_ops.reshape(q, B, T, self.n_head, C // self.n_head) + q = tt_lib.tensor.reshape(q, B, T, self.n_head, C // self.n_head) q = tt_lib.tensor.transpose(q, 1, 2) v = torch_to_tt_tensor_rm(v, self.device) - v = fallback_ops.reshape(v, B, T, self.n_head, C // self.n_head) + v = tt_lib.tensor.reshape(v, B, T, self.n_head, C // self.n_head) v = tt_lib.tensor.transpose(v, 1, 2) # manual implementation of attention @@ -107,12 +112,14 @@ def forward(self, x: tt_lib.tensor.Tensor) -> tt_lib.tensor.Tensor: tt_att = torch_to_tt_tensor_rm(att, self.device, put_on_device=False) - tt_att = fallback_ops.softmax(tt_att, dim=-1) + tt_att = tt_lib.tensor.softmax( + tt_att + ) # Using tt_lib.tensor.softmax reduces pcc from 0.99 to 0.98 for whole model tt_y = tt_lib.tensor.bmm(tt_att, v) tt_y = tt_lib.tensor.transpose(tt_y, 1, -2) - tt_y = fallback_ops.reshape(tt_y, 1, B, T, C) + tt_y = tt_lib.tensor.reshape(tt_y, 1, B, T, C) # output projection x2 = self.c_proj(tt_y) diff --git a/models/experimental/nanogpt/tt/nanogpt_block.py b/models/experimental/nanogpt/tt/nanogpt_block.py index 349b9f48c75..1606ab2b42d 100644 --- a/models/experimental/nanogpt/tt/nanogpt_block.py +++ b/models/experimental/nanogpt/tt/nanogpt_block.py @@ -8,55 +8,36 @@ import models.experimental.nanogpt.tt.nanogpt_attention as nanogpt_attention -from models.utility_functions import ( - torch_to_tt_tensor_rm, -) - - class TtBlock(nn.Module): - def __init__(self, config, state_dict, base_address, device): + def __init__(self, config, base_address, device, tt_cache_path, dtype): super().__init__() self.device = device self.config = config - self.beta_1 = torch_to_tt_tensor_rm( - state_dict[f"{base_address}.ln_1.bias"], self.device - ) + self.beta_1 = tt_lib.tensor.load_tensor(tt_cache_path + base_address + ".ln_1.bias" + str(dtype) + ".bin") - self.gamma_1 = torch_to_tt_tensor_rm( - state_dict[f"{base_address}.ln_1.weight"], self.device - ) + self.gamma_1 = tt_lib.tensor.load_tensor(tt_cache_path + base_address + ".ln_1.weight" + str(dtype) + ".bin") self.ln_1 = tt_lib.tensor.layernorm self.attn = nanogpt_attention.TtCausalSelfAttention( - config, state_dict, f"{base_address}.attn", device + config, f"{base_address}.attn", device, tt_cache_path, dtype ) - self.beta_2 = torch_to_tt_tensor_rm( - state_dict[f"{base_address}.ln_2.bias"], self.device - ) + self.beta_2 = tt_lib.tensor.load_tensor(tt_cache_path + base_address + ".ln_2.bias" + str(dtype) + ".bin") - self.gamma_2 = torch_to_tt_tensor_rm( - state_dict[f"{base_address}.ln_2.weight"], self.device - ) + self.gamma_2 = tt_lib.tensor.load_tensor(tt_cache_path + base_address + ".ln_2.weight" + str(dtype) + ".bin") self.ln_2 = tt_lib.tensor.layernorm - self.mlp = nanogpt_mlp.TtMLP( - f"{base_address}.mlp", self.config, state_dict, device - ) + self.mlp = nanogpt_mlp.TtMLP(f"{base_address}.mlp", self.config, device, tt_cache_path, dtype) def forward(self, x: tt_lib.tensor.Tensor) -> tt_lib.tensor.Tensor: - tmp = self.attn.forward( - self.ln_1(x, eps=1e-5, gamma=self.gamma_1, beta=self.beta_1) - ) + tmp = self.attn.forward(self.ln_1(x, eps=1e-5, gamma=self.gamma_1, beta=self.beta_1)) x = tt_lib.tensor.add(x, tmp) - tmp = self.mlp.forward( - self.ln_2(x, eps=1e-5, gamma=self.gamma_2, beta=self.beta_2) - ) + tmp = self.mlp.forward(self.ln_2(x, eps=1e-5, gamma=self.gamma_2, beta=self.beta_2)) x = tt_lib.tensor.add(x, tmp) return x diff --git a/models/experimental/nanogpt/tt/nanogpt_mlp.py b/models/experimental/nanogpt/tt/nanogpt_mlp.py index f97c5b49b9f..fa901c928b5 100644 --- a/models/experimental/nanogpt/tt/nanogpt_mlp.py +++ b/models/experimental/nanogpt/tt/nanogpt_mlp.py @@ -6,29 +6,27 @@ import tt_lib from models.helper_funcs import Linear -from models.utility_functions import ( - torch_to_tt_tensor_rm, -) - class TtMLP(torch.nn.Module): - def __init__(self, base_address, config, state_dict, device): + def __init__(self, base_address, config, device, tt_cache_path, dtype): super().__init__() # Get the weights - self.tt_weight_c_fc = state_dict[f"{base_address}.c_fc.weight"] - self.tt_weight_c_proj = state_dict[f"{base_address}.c_proj.weight"] + self.tt_weight_c_fc = tt_lib.tensor.load_tensor( + tt_cache_path + base_address + ".c_fc.weight" + str(dtype) + ".bin" + ) + self.tt_weight_c_proj = tt_lib.tensor.load_tensor( + tt_cache_path + base_address + ".c_proj.weight" + str(dtype) + ".bin" + ) + self.config = config self.device = device - # Push weights to Tt device - self.tt_weight_c_fc = torch_to_tt_tensor_rm(self.tt_weight_c_fc, self.device) - - self.tt_weight_c_proj = torch_to_tt_tensor_rm(self.tt_weight_c_proj, self.device) - # Load biases - self.tt_bias_c_fc = torch_to_tt_tensor_rm(state_dict[f"{base_address}.c_fc.bias"], self.device) + self.tt_bias_c_fc = tt_lib.tensor.load_tensor(tt_cache_path + base_address + ".c_fc.bias" + str(dtype) + ".bin") - self.tt_bias_c_proj = torch_to_tt_tensor_rm(state_dict[f"{base_address}.c_proj.bias"], self.device) + self.tt_bias_c_proj = tt_lib.tensor.load_tensor( + tt_cache_path + base_address + ".c_proj.bias" + str(dtype) + ".bin" + ) self.tt_weight_c_fc = tt_lib.tensor.transpose(self.tt_weight_c_fc, -2, -1) self.tt_weight_c_proj = tt_lib.tensor.transpose(self.tt_weight_c_proj, -2, -1) diff --git a/models/experimental/nanogpt/tt/nanogpt_model.py b/models/experimental/nanogpt/tt/nanogpt_model.py index 03d16aa654e..eabe070c7bd 100644 --- a/models/experimental/nanogpt/tt/nanogpt_model.py +++ b/models/experimental/nanogpt/tt/nanogpt_model.py @@ -6,16 +6,19 @@ import torch.nn as nn import tt_lib from models.helper_funcs import Linear +import tt_lib.fallback_ops as fallback_ops import models.experimental.nanogpt.tt.nanogpt_block as nanogpt_block +from models.experimental.nanogpt.nanogpt_utils import unpad_from_zero from models.utility_functions import ( torch_to_tt_tensor_rm, + tt_to_torch_tensor, ) class TtGPT(nn.Module): - def __init__(self, config, state_dict, device): + def __init__(self, config, device, tt_cache_path, dtype): super().__init__() assert config.vocab_size is not None @@ -24,48 +27,47 @@ def __init__(self, config, state_dict, device): self.config.block_size = 1024 base_address = f"transformer" self.device = device - self.beta = torch_to_tt_tensor_rm( - state_dict[f"{base_address}.ln_f.bias"], self.device - ) - self.gamma = torch_to_tt_tensor_rm( - state_dict[f"{base_address}.ln_f.weight"], self.device - ) + self.beta = tt_lib.tensor.load_tensor(tt_cache_path + base_address + ".ln_f.bias" + str(dtype) + ".bin") + + self.gamma = tt_lib.tensor.load_tensor(tt_cache_path + base_address + ".ln_f.weight" + str(dtype) + ".bin") self.wte = nn.Embedding(config.vocab_size, config.n_embd) self.wpe = nn.Embedding(self.config.block_size, config.n_embd) - self.wte.weight = torch.nn.Parameter(state_dict[f"{base_address}.wte.weight"]) + self.wte.weight = torch.nn.Parameter(torch.load(tt_cache_path + "transformer.wte.weight.pt")) - self.wpe.weight = torch.nn.Parameter(state_dict[f"{base_address}.wpe.weight"]) + self.wpe.weight = torch.nn.Parameter(torch.load(tt_cache_path + "transformer.wpe.weight.pt")) blocks = [] for i in range(config.n_layer): - block = nanogpt_block.TtBlock( - self.config, state_dict, f"{base_address}.h.{i}", self.device - ) + block = nanogpt_block.TtBlock(self.config, f"{base_address}.h.{i}", self.device, tt_cache_path, dtype) blocks.append(block) self.h = nn.ModuleList(blocks) self.ln_f = tt_lib.tensor.layernorm - # Push weights to Tt device - tt_lm_weight = torch_to_tt_tensor_rm(state_dict["lm_head.weight"], self.device) + tt_lm_weight = tt_lib.tensor.load_tensor(tt_cache_path + "lm_head.weight" + str(dtype) + ".bin") + + weight = unpad_from_zero(tt_lm_weight, (1, 1, self.config.vocab_size, self.config.n_embd)) + weight_torch = weight + weight = torch_to_tt_tensor_rm(weight, device=self.device) - self.lm_head = Linear(self.config.n_embd, self.config.vocab_size, tt_lm_weight) + self.lm_head = Linear(self.config.n_embd, self.config.vocab_size, weight) - self.wte.weight = nn.Parameter( - state_dict["lm_head.weight"] - ) # https://paperswithcode.com/method/weight-tying + self.wte.weight = nn.Parameter(weight_torch.squeeze()) # https://paperswithcode.com/method/weight-tying def forward(self, idx: torch.Tensor) -> tt_lib.tensor.Tensor: b, t = idx.shape assert ( t <= self.config.block_size ), f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}" - pos = torch.arange(0, t, dtype=torch.long).unsqueeze(0) # shape (1, t) + pos = tt_lib.tensor.arange(0, t, 1) + pos = tt_to_torch_tensor(pos) + pos = pos.squeeze(0).squeeze(0) + pos = pos.to(dtype=torch.int64) # forward the GPT model itself tok_emb = self.wte(idx) # token embeddings of shape (b, t, n_embd) @@ -74,7 +76,13 @@ def forward(self, idx: torch.Tensor) -> tt_lib.tensor.Tensor: tt_tok_emb = torch_to_tt_tensor_rm(tok_emb, self.device) tt_pos_emb = torch_to_tt_tensor_rm(pos_emb, self.device) - tt_x = tt_lib.tensor.add(tt_tok_emb, tt_pos_emb) + tt_tok_emb = tt_lib.tensor.permute(tt_tok_emb, (0, 2, 1, 3)) + tt_pos_emb = tt_lib.tensor.permute(tt_pos_emb, (0, 2, 1, 3)) + + tt_x = tt_lib.tensor.bcast(tt_tok_emb, tt_pos_emb, tt_lib.tensor.BcastOpMath.ADD, tt_lib.tensor.BcastOpDim.H) + tt_tok_emb.deallocate() + tt_pos_emb.deallocate() + tt_x = tt_lib.tensor.permute(tt_x, (0, 2, 1, 3)) for block in self.h: tt_x = block.forward(tt_x) @@ -83,3 +91,57 @@ def forward(self, idx: torch.Tensor) -> tt_lib.tensor.Tensor: logits = self.lm_head(tt_x) return logits + + def generate( + self, + idx: torch.Tensor, + max_new_tokens: int = 20, + temperature: int = 1.0, + top_k=None, + ) -> torch.Tensor: + """ + Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete + the sequence max_new_tokens times, feeding the predictions back into the model each time. + Most likely you'll want to make sure to be in model.eval() mode of operation for this. + """ + for _ in range(max_new_tokens): + # if the sequence context is growing too long we must crop it at block_size + idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size :] + # forward the model to get the logits for the index in the sequence + tt_logits = self.forward(idx_cond) + + logits_shapes = tt_logits.shape() + + slice_list = [ + slice(None), + slice(None), + slice(logits_shapes[2] - 1, logits_shapes[2]), + slice(None), + ] + tt_logits = fallback_ops.tensor_slice(tt_logits, slice_list) + + tt_temperature = fallback_ops.full(tt_logits.shape(), temperature) + + tt_temperature = tt_lib.tensor.recip(tt_temperature) + tt_logits = tt_lib.tensor.mul(tt_logits, tt_temperature) + + logits = tt_to_torch_tensor(tt_logits) + # optionally crop the logits to only the top k options + if top_k is not None: + v, _ = torch.topk(logits, min(top_k, logits.size(-1))) + logits[logits < v[:, [-1]]] = -float("Inf") + + # apply softmax to convert logits to (normalized) probabilities + tt_logits = torch_to_tt_tensor_rm(logits, self.device, put_on_device=False) + tt_probs = fallback_ops.softmax(tt_logits, dim=-1) + probs = tt_to_torch_tensor(tt_probs) + probs = probs.squeeze(0) + probs = probs.squeeze(0) + + # sample from the distribution + idx_next = torch.multinomial(probs, num_samples=1) + + # append sampled index to the running sequence and continue + idx = torch.cat((idx, idx_next), dim=1) + + return idx diff --git a/tests/scripts/run_performance.sh b/tests/scripts/run_performance.sh index f7983e9f474..9109f8e1bd7 100755 --- a/tests/scripts/run_performance.sh +++ b/tests/scripts/run_performance.sh @@ -34,6 +34,8 @@ run_perf_models() { env pytest models/experimental/bloom/tests -m $pipeline_type + env pytest models/experimental/nanogpt/tests -m $pipeline_type + ## Merge all the generated reports env python models/perf/merge_perf_results.py }