#3974: nanogpt uplift and move weights to weka path
1 parent 1a2c99f · commit fe3737f · 11 changed files with 340 additions and 100 deletions.
@@ -0,0 +1,109 @@
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

from models.utility_functions import tt2torch_tensor
import torch
import tt_lib
from transformers import GPT2LMHeadModel
from tt_lib.utils import pad_weight
from pathlib import Path
import os


def unpad_from_zero(x, desired_shape):
    # If the TT tensor already matches the last two target dims, convert it directly;
    # otherwise move it to host, switch to ROW_MAJOR, and strip the tile padding.
    if x.shape()[-1] == desired_shape[-1] and x.shape()[-2] == desired_shape[-2]:
        x = tt2torch_tensor(x)
    else:
        x = x.cpu()
        if x.layout() != tt_lib.tensor.Layout.ROW_MAJOR:
            x = x.to(tt_lib.tensor.Layout.ROW_MAJOR)
        x = x.unpad(
            (0, 0, 0, 0), (desired_shape[0] - 1, desired_shape[1] - 1, desired_shape[2] - 1, desired_shape[3] - 1)
        )
        x = x.to_torch().to(torch.float)
    return x


def cache_weights_in_weka(device, dtype, reset_seeds):
    model_hf = GPT2LMHeadModel.from_pretrained("gpt2")
    state_dict = model_hf.state_dict()
    weights_dtype = dtype

    # initial weights are stored in "models/experimental/nanogpt/weights/" and moved to the weka path
    file_name = "models/experimental/nanogpt/weights/"
    for key, value in state_dict.items():
        # Token (wte) and position (wpe) embedding tables are kept as torch .pt files.
        if key.startswith("transformer.wte.") or key.startswith("transformer.wpe."):
            torch.save(value, file_name + str(key) + ".pt")
            continue
        elif len(value.shape) == 0:
            continue
        # Expand to 4D, tilize (padding to 32x32 tiles if needed), and dump as a TT .bin file.
        while len(value.shape) < 4:
            value = value.unsqueeze(0)
        if value.shape[-2] % 32 == 0 and value.shape[-1] % 32 == 0:
            value = tt_lib.tensor.Tensor(
                value.reshape(-1).tolist(),
                value.shape,
                weights_dtype,
                tt_lib.tensor.Layout.ROW_MAJOR,
            ).to(tt_lib.tensor.Layout.TILE)
        else:
            value = pad_weight(value)
            value = tt_lib.tensor.Tensor(
                value.reshape(-1).tolist(),
                value.shape,
                weights_dtype,
                tt_lib.tensor.Layout.ROW_MAJOR,
            ).to(tt_lib.tensor.Layout.TILE)
        tt_lib.tensor.dump_tensor(file_name + str(key) + str(weights_dtype) + ".bin", value)
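
A one-off call along the following lines would populate the local weights directory before the files are copied to the weka mount. This is a minimal sketch, not taken from the commit: it assumes tt_lib.tensor.DataType.BFLOAT16 as the weight dtype, that the unused device and reset_seeds arguments can be passed as None, that the target weights/ directory already exists, and that this helper is in scope (imported from this module).

# Sketch (assumptions noted above): populate models/experimental/nanogpt/weights/ once.
import tt_lib
cache_weights_in_weka(device=None, dtype=tt_lib.tensor.DataType.BFLOAT16, reset_seeds=None)
# The resulting .pt/.bin files can then be copied to the per-model weka directory
# returned by get_tt_cache_path("gpt2") below.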
"""This function will load weights from the state_dict and check if the needed weights are available in given path. | ||
If they are not available, it will convert torch tensor weights to TT tensor weights and store them in the given path.""" | ||
|
||
|
||
def store_weights(model_version, file_name, base_address, dtype): | ||
model_hf = GPT2LMHeadModel.from_pretrained(model_version) | ||
state_dict = model_hf.state_dict() | ||
weights_dtype = dtype | ||
|
||
for key, value in state_dict.items(): | ||
if base_address == "" and ( | ||
(key.startswith("transformer.wte.") and os.path.exists(file_name + str(key) + ".pt") == False) | ||
or (key.startswith("transformer.wpe.") and os.path.exists(file_name + str(key) + ".pt") == False) | ||
): | ||
torch.save(value, file_name + str(key) + ".pt") | ||
continue | ||
if key.startswith("transformer.wte.") or key.startswith("transformer.wpe.") or (len(value.shape) == 0): | ||
continue | ||
if (os.path.exists(file_name + str(key) + str(weights_dtype) + ".bin")) or ( | ||
key.startswith(base_address) == False and base_address != "" | ||
): | ||
continue | ||
while len(value.shape) < 4: | ||
value = value.unsqueeze(0) | ||
if value.shape[-2] % 32 == 0 and value.shape[-1] % 32 == 0: | ||
value = tt_lib.tensor.Tensor( | ||
value.reshape(-1).tolist(), | ||
value.shape, | ||
weights_dtype, | ||
tt_lib.tensor.Layout.ROW_MAJOR, | ||
).to(tt_lib.tensor.Layout.TILE) | ||
else: | ||
value = pad_weight(value) | ||
value = tt_lib.tensor.Tensor( | ||
value.reshape(-1).tolist(), | ||
value.shape, | ||
weights_dtype, | ||
tt_lib.tensor.Layout.ROW_MAJOR, | ||
).to(tt_lib.tensor.Layout.TILE) | ||
tt_lib.tensor.dump_tensor(file_name + str(key) + str(weights_dtype) + ".bin", value) | ||


def get_tt_cache_path(model_version):
    # Prefer the shared weka mount; fall back to a local per-model directory when it is absent.
    tt_cache_path = Path("/mnt/MLPerf/tt_dnn-models/tt/NanoGPT") / model_version
    if tt_cache_path.exists():
        return str(tt_cache_path) + "/"
    else:
        Path(f"models/experimental/nanogpt/datasets/{model_version}").mkdir(parents=True, exist_ok=True)
        return str(Path(f"models/experimental/nanogpt/datasets/{model_version}")) + "/"
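
Taken together, a caller would typically resolve the cache directory first and then lazily materialize any missing TT weight files into it. The snippet below is an illustrative sketch under the same assumptions as above (tt_lib.tensor.DataType.BFLOAT16 as the weight dtype, helpers imported from this module); an empty base_address caches all layers.

# Sketch: resolve the cache dir (weka if mounted, local fallback otherwise) and fill it.
import tt_lib
tt_cache_path = get_tt_cache_path("gpt2")
store_weights(model_version="gpt2", file_name=tt_cache_path, base_address="", dtype=tt_lib.tensor.DataType.BFLOAT16)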
@@ -0,0 +1,24 @@
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

from transformers import GPT2LMHeadModel
from models.experimental.nanogpt.tt.nanogpt_model import TtGPT


def _nanogpt(config, device, tt_cache_path, dtype):
    return TtGPT(
        config=config,
        device=device,
        tt_cache_path=tt_cache_path,
        dtype=dtype,
    )


def nanogpt_model(device, dtype) -> TtGPT:
    # Build the TT NanoGPT model from the HF GPT-2 config, pointing it at the weka weight cache path.
    model_name = "gpt2"
    model = GPT2LMHeadModel.from_pretrained(model_name)
    config = model.config
    tt_cache_path = "/mnt/MLPerf/tt_dnn-models/tt/NanoGPT/"
    model = _nanogpt(config, device, tt_cache_path, dtype)
    return model
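
For reference, constructing the TT model from this helper would look roughly like the sketch below. It is illustrative only: it assumes a tt_lib device has already been opened by the caller (for example by a test fixture) and that tt_lib.tensor.DataType.BFLOAT16 is the intended weight dtype.

# Sketch: given an already-initialized TT device handle, build the NanoGPT wrapper.
import tt_lib
tt_model = nanogpt_model(device, tt_lib.tensor.DataType.BFLOAT16)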