#3974: nanogpt uplift and move weights to weka path #4221

Merged 1 commit on Jan 19, 2024
109 changes: 109 additions & 0 deletions models/experimental/nanogpt/nanogpt_utils.py
@@ -0,0 +1,109 @@
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

from models.utility_functions import tt2torch_tensor
import torch
import tt_lib
from transformers import GPT2LMHeadModel
from tt_lib.utils import pad_weight
from pathlib import Path
import os


def unpad_from_zero(x, desired_shape):
    # If the last two dims already match, the tensor was never padded; convert directly.
    if x.shape()[-1] == desired_shape[-1] and x.shape()[-2] == desired_shape[-2]:
        x = tt2torch_tensor(x)
    else:
        x = x.cpu()
        if x.layout() != tt_lib.tensor.Layout.ROW_MAJOR:
            x = x.to(tt_lib.tensor.Layout.ROW_MAJOR)
        # Strip the zero padding back down to the desired shape (end indices are inclusive).
        x = x.unpad(
            (0, 0, 0, 0), (desired_shape[0] - 1, desired_shape[1] - 1, desired_shape[2] - 1, desired_shape[3] - 1)
        )
        x = x.to_torch().to(torch.float)
    return x
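A minimal round-trip sketch of how `unpad_from_zero` is used (hypothetical, not part of this PR; in the tests the input tensor comes back from a device op, here a host tensor is built the same way the weights are below):

import torch
import tt_lib
from tt_lib.utils import pad_weight
from models.experimental.nanogpt.nanogpt_utils import unpad_from_zero

act = torch.rand(1, 1, 43, 768)
padded = pad_weight(act)  # pads rows 43 -> 64 so the tensor can be tilized
tt_act = tt_lib.tensor.Tensor(
    padded.reshape(-1).tolist(),
    padded.shape,
    tt_lib.tensor.DataType.BFLOAT16,
    tt_lib.tensor.Layout.ROW_MAJOR,
).to(tt_lib.tensor.Layout.TILE)
restored = unpad_from_zero(tt_act, (1, 1, 43, 768))  # torch tensor, shape (1, 1, 43, 768)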


def cache_weights_in_weka(device, dtype, reset_seeds):
    model_hf = GPT2LMHeadModel.from_pretrained("gpt2")
    state_dict = model_hf.state_dict()
    weights_dtype = dtype

    # Initial weights are stored in "models/experimental/nanogpt/weights/" and moved to the weka path.
    file_name = "models/experimental/nanogpt/weights/"
    for key, value in state_dict.items():
        # Embedding tables stay as torch .pt files; scalar entries are skipped.
        if key.startswith("transformer.wte.") or key.startswith("transformer.wpe."):
            torch.save(value, file_name + str(key) + ".pt")
            continue
        elif len(value.shape) == 0:
            continue
        while len(value.shape) < 4:
            value = value.unsqueeze(0)
        if value.shape[-2] % 32 == 0 and value.shape[-1] % 32 == 0:
            value = tt_lib.tensor.Tensor(
                value.reshape(-1).tolist(),
                value.shape,
                weights_dtype,
                tt_lib.tensor.Layout.ROW_MAJOR,
            ).to(tt_lib.tensor.Layout.TILE)
        else:
            # Pad to 32x32 tile boundaries before tilizing.
            value = pad_weight(value)
            value = tt_lib.tensor.Tensor(
                value.reshape(-1).tolist(),
                value.shape,
                weights_dtype,
                tt_lib.tensor.Layout.ROW_MAJOR,
            ).to(tt_lib.tensor.Layout.TILE)
        tt_lib.tensor.dump_tensor(file_name + str(key) + str(weights_dtype) + ".bin", value)


"""This function will load weights from the state_dict and check if the needed weights are available in given path.
If they are not available, it will convert torch tensor weights to TT tensor weights and store them in the given path."""


def store_weights(model_version, file_name, base_address, dtype):
model_hf = GPT2LMHeadModel.from_pretrained(model_version)
state_dict = model_hf.state_dict()
weights_dtype = dtype

for key, value in state_dict.items():
if base_address == "" and (
(key.startswith("transformer.wte.") and os.path.exists(file_name + str(key) + ".pt") == False)
or (key.startswith("transformer.wpe.") and os.path.exists(file_name + str(key) + ".pt") == False)
):
torch.save(value, file_name + str(key) + ".pt")
continue
if key.startswith("transformer.wte.") or key.startswith("transformer.wpe.") or (len(value.shape) == 0):
continue
if (os.path.exists(file_name + str(key) + str(weights_dtype) + ".bin")) or (
key.startswith(base_address) == False and base_address != ""
):
continue
while len(value.shape) < 4:
value = value.unsqueeze(0)
if value.shape[-2] % 32 == 0 and value.shape[-1] % 32 == 0:
value = tt_lib.tensor.Tensor(
value.reshape(-1).tolist(),
value.shape,
weights_dtype,
tt_lib.tensor.Layout.ROW_MAJOR,
).to(tt_lib.tensor.Layout.TILE)
else:
value = pad_weight(value)
value = tt_lib.tensor.Tensor(
value.reshape(-1).tolist(),
value.shape,
weights_dtype,
tt_lib.tensor.Layout.ROW_MAJOR,
).to(tt_lib.tensor.Layout.TILE)
tt_lib.tensor.dump_tensor(file_name + str(key) + str(weights_dtype) + ".bin", value)


def get_tt_cache_path(model_version):
    tt_cache_path = Path("/mnt/MLPerf/tt_dnn-models/tt/NanoGPT") / model_version
    if tt_cache_path.exists():
        return str(tt_cache_path) + "/"
    else:
        Path(f"models/experimental/nanogpt/datasets/{model_version}").mkdir(parents=True, exist_ok=True)
        return str(Path(f"models/experimental/nanogpt/datasets/{model_version}")) + "/"
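Taken together, these helpers give the tests a simple caching flow: resolve the cache directory, populate it once, then reload tilized weights straight from disk. A sketch of that flow (hypothetical caller code, assuming the "gpt2" checkpoint; `tt_lib.tensor.load_tensor` is assumed as the read counterpart of the `dump_tensor` call above):

import tt_lib
from models.experimental.nanogpt.nanogpt_utils import get_tt_cache_path, store_weights

dtype = tt_lib.tensor.DataType.BFLOAT16
tt_cache_path = get_tt_cache_path("gpt2")  # weka path if mounted, local fallback otherwise

# First use: convert and cache every module's weights (base_address="" means no filter).
store_weights(model_version="gpt2", file_name=tt_cache_path, base_address="", dtype=dtype)

# Later uses: reload a cached tilized weight; the file name mirrors the dump_tensor naming.
weight = tt_lib.tensor.load_tensor(tt_cache_path + "transformer.h.0.attn.c_attn.weight" + str(dtype) + ".bin")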
25 changes: 19 additions & 6 deletions models/experimental/nanogpt/tests/test_nanogpt_attention.py
@@ -4,12 +4,15 @@

import torch
import pytest
import tt_lib
import os
from pathlib import Path

from transformers import GPT2LMHeadModel


from loguru import logger
import models.experimental.nanogpt.tt.nanogpt_attention as nanogpt_attention
from models.experimental.nanogpt.nanogpt_utils import get_tt_cache_path, store_weights

from models.utility_functions import (
    tt_to_torch_tensor,
@@ -19,16 +22,17 @@
)


@pytest.mark.parametrize(
    "dtype",
    (tt_lib.tensor.DataType.BFLOAT16,),
)
@pytest.mark.parametrize(
    "pcc",
    ((0.99,),),
)
-def test_nanogpt_attn(device, pcc, reset_seeds):
+def test_nanogpt_attn(device, pcc, dtype, reset_seeds):
    # Prepare input
    model_hf = GPT2LMHeadModel.from_pretrained("gpt2")
    sd = model_hf.state_dict()
    config = model_hf.config
    model_hf.eval()
    block = 0
@@ -38,8 +42,17 @@ def test_nanogpt_attn(device, pcc, reset_seeds):
    pt_attn = model_hf.transformer.h[block].attn
    pt_out = pt_attn.forward(test_in)

    model_version = "gpt2"
    tt_cache_path = get_tt_cache_path(model_version)

    # Populate the local fallback cache if it does not yet hold the full set of weight files.
    if (
        tt_cache_path == (str(Path(f"models/experimental/nanogpt/datasets/{model_version}")) + "/")
        and len(os.listdir(f"models/experimental/nanogpt/datasets/{model_version}")) < 320
    ):
        store_weights(model_version=model_version, file_name=tt_cache_path, dtype=dtype, base_address=base_address)

    tt_test_in = torch_to_tt_tensor_rm(test_in, device)
-    tt_attn = nanogpt_attention.TtCausalSelfAttention(config, sd, base_address, device)
+    tt_attn = nanogpt_attention.TtCausalSelfAttention(config, base_address, device, tt_cache_path, dtype)

    tt_out = tt_attn.forward(tt_test_in)

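Every test in this PR repeats the same populate-the-cache guard before building its TT module. A hypothetical helper (not part of the PR) folds the pattern into one call; 320 is the threshold the tests use for a fully populated gpt2 cache directory:

import os
from pathlib import Path

from models.experimental.nanogpt.nanogpt_utils import get_tt_cache_path, store_weights


def ensure_weights_cached(model_version, base_address, dtype):
    tt_cache_path = get_tt_cache_path(model_version)
    local_dir = f"models/experimental/nanogpt/datasets/{model_version}"
    # Only the local fallback needs populating; the weka path ships pre-filled.
    if tt_cache_path == str(Path(local_dir)) + "/" and len(os.listdir(local_dir)) < 320:
        store_weights(model_version=model_version, file_name=tt_cache_path, dtype=dtype, base_address=base_address)
    return tt_cache_path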
22 changes: 18 additions & 4 deletions models/experimental/nanogpt/tests/test_nanogpt_block.py
@@ -4,11 +4,15 @@

import torch
import pytest
import tt_lib
from pathlib import Path
import os

from transformers import GPT2LMHeadModel

from loguru import logger
import models.experimental.nanogpt.tt.nanogpt_block as nanogpt_block
from models.experimental.nanogpt.nanogpt_utils import get_tt_cache_path, store_weights

from models.utility_functions import (
    tt_to_torch_tensor,
@@ -18,14 +22,16 @@
)


@pytest.mark.parametrize(
    "dtype",
    (tt_lib.tensor.DataType.BFLOAT16,),
)
@pytest.mark.parametrize(
    "pcc",
    ((0.99,),),
)
-def test_nanogpt_block(device, pcc, reset_seeds):
+def test_nanogpt_block(device, pcc, dtype, reset_seeds):
    model_hf = GPT2LMHeadModel.from_pretrained("gpt2")
    sd = model_hf.state_dict()
    config = model_hf.config
    model_hf.eval()
    block = 0
@@ -36,8 +42,16 @@ def test_nanogpt_block(device, pcc, reset_seeds):
    pt_out = pt_block.forward(test_in)

    tt_test_in = torch_to_tt_tensor_rm(test_in, device)
    model_version = "gpt2"
    tt_cache_path = get_tt_cache_path(model_version)

    if (
        tt_cache_path == (str(Path(f"models/experimental/nanogpt/datasets/{model_version}")) + "/")
        and len(os.listdir(f"models/experimental/nanogpt/datasets/{model_version}")) < 320
    ):
        store_weights(model_version=model_version, file_name=tt_cache_path, dtype=dtype, base_address=base_address)

-    tt_block = nanogpt_block.TtBlock(config, sd, base_address, device)
+    tt_block = nanogpt_block.TtBlock(config, base_address, device, tt_cache_path, dtype)
    tt_block.eval()

    tt_out = tt_block.forward(tt_test_in)
23 changes: 19 additions & 4 deletions models/experimental/nanogpt/tests/test_nanogpt_mlp.py
@@ -4,6 +4,10 @@

import torch
import pytest
import tt_lib
from models.experimental.nanogpt.nanogpt_utils import get_tt_cache_path, store_weights
from pathlib import Path
import os

from transformers import GPT2LMHeadModel

@@ -19,22 +23,33 @@
)


@pytest.mark.parametrize(
    "dtype",
    (tt_lib.tensor.DataType.BFLOAT16,),
)
@pytest.mark.parametrize(
    "pcc",
    ((0.99,),),
)
-def test_nanogpt_mlp(device, pcc, reset_seeds):
+def test_nanogpt_mlp(device, pcc, dtype, reset_seeds):
    model_hf = GPT2LMHeadModel.from_pretrained("gpt2")
    sd = model_hf.state_dict()
    config = model_hf.config
    model_hf.eval()
    block = 0
    base_address = f"transformer.h.{block}.mlp"

    test_in = torch.rand(1, 43, 768)
    tt_test_in = torch_to_tt_tensor_rm(test_in, device)
-    tt_mlp = nanogpt_mlp.TtMLP(base_address, config, sd, device)
    model_version = "gpt2"
    tt_cache_path = get_tt_cache_path(model_version)

    if (
        tt_cache_path == (str(Path(f"models/experimental/nanogpt/datasets/{model_version}")) + "/")
        and len(os.listdir(f"models/experimental/nanogpt/datasets/{model_version}")) < 320
    ):
        store_weights(model_version=model_version, file_name=tt_cache_path, dtype=dtype, base_address=base_address)

+    tt_mlp = nanogpt_mlp.TtMLP(base_address, config, device, tt_cache_path, dtype)

    tt_out = tt_mlp.forward(tt_test_in)

@@ -2,28 +2,32 @@

# SPDX-License-Identifier: Apache-2.0

import torch
import tt_lib
import pytest

from transformers import GPT2Tokenizer, GPT2LMHeadModel
from models.experimental.nanogpt.nanogpt_utils import get_tt_cache_path, store_weights
from pathlib import Path
import os

from loguru import logger
import models.experimental.nanogpt.tt.nanogpt_model as nanogpt_model

from models.utility_functions import tt_to_torch_tensor, comp_allclose, comp_pcc



@pytest.mark.parametrize(
    "dtype",
    (tt_lib.tensor.DataType.BFLOAT16,),
)
@pytest.mark.parametrize(
    "pcc, prompt",
-    ((0.99, "Hello, my dog is a little"),),
+    ((0.98, "Hello, my dog is a little"),),
)
-def test_nanogpt_model_real(device, pcc, prompt, reset_seeds):
+def test_nanogpt_model_real(device, pcc, prompt, dtype, reset_seeds):
    # Prepare input
    model_hf = GPT2LMHeadModel.from_pretrained("gpt2")
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    sd = model_hf.state_dict()
    model_hf.eval()

    inputs = tokenizer(prompt, return_tensors="pt", padding=False)
@@ -33,7 +37,17 @@ def test_nanogpt_model_real(device, pcc, prompt, reset_seeds):

    config = model_hf.config

-    tt_model = nanogpt_model.TtGPT(config, sd, device)
    base_address = ""
    model_version = "gpt2"
    tt_cache_path = get_tt_cache_path(model_version)

    if (
        tt_cache_path == (str(Path(f"models/experimental/nanogpt/datasets/{model_version}")) + "/")
        and len(os.listdir(f"models/experimental/nanogpt/datasets/{model_version}")) < 320
    ):
        store_weights(model_version=model_version, file_name=tt_cache_path, dtype=dtype, base_address=base_address)

+    tt_model = nanogpt_model.TtGPT(config, device, tt_cache_path, dtype)

    tt_out = tt_model.forward(inputs.input_ids)
24 changes: 24 additions & 0 deletions models/experimental/nanogpt/tt/nanogpt.py
@@ -0,0 +1,24 @@
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

from transformers import GPT2LMHeadModel
from models.experimental.nanogpt.tt.nanogpt_model import TtGPT


def _nanogpt(config, device, tt_cache_path, dtype):
    return TtGPT(
        config=config,
        device=device,
        tt_cache_path=tt_cache_path,
        dtype=dtype,
    )


def nanogpt_model(device, dtype) -> TtGPT:
    model_name = "gpt2"
    model = GPT2LMHeadModel.from_pretrained(model_name)
    config = model.config
    tt_cache_path = "/mnt/MLPerf/tt_dnn-models/tt/NanoGPT/"
    model = _nanogpt(config, device, tt_cache_path, dtype)
    return model
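A minimal end-to-end sketch of this factory (hypothetical driver code, not part of the PR; it assumes the weka cache is mounted, and device setup via `tt_lib.device.CreateDevice`/`CloseDevice` is assumed from the tt_lib API of this period):

import tt_lib
from transformers import GPT2Tokenizer
from models.experimental.nanogpt.tt.nanogpt import nanogpt_model

device = tt_lib.device.CreateDevice(0)

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
inputs = tokenizer("Hello, my dog is a little", return_tensors="pt", padding=False)

model = nanogpt_model(device, tt_lib.tensor.DataType.BFLOAT16)
tt_out = model.forward(inputs.input_ids)

tt_lib.device.CloseDevice(device)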