-
Notifications
You must be signed in to change notification settings - Fork 74
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
#3974: nanogpt uplift and move weights to weka path
- Loading branch information
1 parent
5aeccad
commit 7d782bf
Showing
11 changed files
with
247 additions
and
100 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. | ||
|
||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
from models.utility_functions import tt2torch_tensor | ||
import torch | ||
import tt_lib | ||
from transformers import GPT2LMHeadModel | ||
from tt_lib.utils import pad_weight | ||
|
||
|
||
def unpad_from_zero(x, desired_shape):
    """Convert a tt tensor back to a float torch tensor of *desired_shape*.

    If the last two dims already match the desired shape, the tensor is
    converted directly; otherwise it is pulled to host, forced to row-major
    layout, and unpadded down to *desired_shape* before conversion.
    NOTE(review): assumes desired_shape has exactly 4 entries — confirm.
    """
    last_two_match = (
        x.shape()[-1] == desired_shape[-1] and x.shape()[-2] == desired_shape[-2]
    )
    if last_two_match:
        return tt2torch_tensor(x)

    host_tensor = x.cpu()
    # unpad requires row-major layout; convert only when needed
    if host_tensor.layout() != tt_lib.tensor.Layout.ROW_MAJOR:
        host_tensor = host_tensor.to(tt_lib.tensor.Layout.ROW_MAJOR)
    end_coords = (
        desired_shape[0] - 1,
        desired_shape[1] - 1,
        desired_shape[2] - 1,
        desired_shape[3] - 1,
    )
    host_tensor = host_tensor.unpad((0, 0, 0, 0), end_coords)
    return host_tensor.to_torch().to(torch.float)
|
||
|
||
def cache_weights_in_weka(device, dtype, reset_seeds):
    """Download HF GPT-2 weights and dump them as tt_lib tensors on disk.

    Embedding tables (wte/wpe) are saved as raw torch ``.pt`` files; every
    other non-scalar weight is expanded to rank 4, padded to 32x32 tile
    boundaries when necessary, tilized, and dumped as ``<key><dtype>.bin``.

    NOTE(review): ``device`` and ``reset_seeds`` are unused here — they look
    like pytest fixtures kept for signature compatibility; confirm.

    :param device: unused (fixture).
    :param dtype: tt_lib dtype used for the dumped tensors.
    :param reset_seeds: unused (fixture).
    """
    model_hf = GPT2LMHeadModel.from_pretrained("gpt2")
    state_dict = model_hf.state_dict()
    weights_dtype = dtype

    # initial weights are stored in "models/experimental/nanogpt/weights/" and moved to weka path
    file_name = "models/experimental/nanogpt/weights/"
    for key, value in state_dict.items():
        # Embedding tables are consumed as torch tensors downstream; save as-is.
        if key.startswith("transformer.wte.") or key.startswith("transformer.wpe."):
            torch.save(value, file_name + str(key) + ".pt")
            continue
        # Skip scalar entries — they cannot be tilized.
        if len(value.shape) == 0:
            continue
        # tt_lib tensors are rank-4: prepend singleton dims as needed
        # (replaces the original 1D/2D/3D unsqueeze chain).
        while len(value.shape) < 4:
            value = value.unsqueeze(0)
        # Pad up to 32x32 tile boundaries when the trailing dims don't align.
        if value.shape[-2] % 32 != 0 or value.shape[-1] % 32 != 0:
            value = pad_weight(value)
        # Single tensor-construction path (was duplicated in both branches).
        value = tt_lib.tensor.Tensor(
            value.reshape(-1).tolist(),
            value.shape,
            weights_dtype,
            tt_lib.tensor.Layout.ROW_MAJOR,
        ).to(tt_lib.tensor.Layout.TILE)
        tt_lib.tensor.dump_tensor(file_name + str(key) + str(weights_dtype) + ".bin", value)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. | ||
|
||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
from transformers import GPT2LMHeadModel | ||
from models.experimental.nanogpt.tt.nanogpt_model import TtGPT | ||
|
||
|
||
def _nanogpt(config, device, tt_cache_path, dtype):
    """Instantiate a TtGPT model from an HF config and a weight-cache path."""
    model = TtGPT(
        config=config,
        device=device,
        tt_cache_path=tt_cache_path,
        dtype=dtype,
    )
    return model
|
||
|
||
def nanogpt_model(device, dtype) -> TtGPT:
    """Build a TtGPT using the HF "gpt2" config and the weka weight cache."""
    model_name = "gpt2"
    hf_model = GPT2LMHeadModel.from_pretrained(model_name)
    # Weights were pre-dumped to this weka mount by cache_weights_in_weka.
    tt_cache_path = "/mnt/MLPerf/tt_dnn-models/tt/NanoGPT/"
    return _nanogpt(hf_model.config, device, tt_cache_path, dtype)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.