#3824: move mistral embedding weights to weka
vigneshkeerthivasanx committed Nov 29, 2023
1 parent 8cd46ab commit c4a4564
Showing 5 changed files with 4 additions and 7 deletions.
1 change: 0 additions & 1 deletion models/experimental/mistral/demo/gs_demo.py

@@ -38,7 +38,6 @@ def test_gs_demo_single_input_inference(batch_size, model_location_generator, de
     tt_cache_path = "/mnt/MLPerf/tt_dnn-models/tt/Mistral/"
     tt_model = TtTransformer(
         args=model_args,
-        state_dict=state_dict,
         device=device,
         base_address=base_address,
         tt_cache_path=tt_cache_path,
3 changes: 3 additions & 0 deletions models/experimental/mistral/mistral_utils.py

@@ -92,6 +92,9 @@ def cache_weights_in_weka(model_location_generator, device, dtype, reset_seeds):
     # initial weights are stored in "models/experimental/mistral/weights/" and moved to weka path
     file_name = "models/experimental/mistral/weights/"
     for key, value in state_dict.items():
+        if "tok_embeddings" in key:
+            torch.save(value, file_name + str(key) + ".pt")
+            continue
         if len(value.shape) == 1:
             value = value.unsqueeze(0).unsqueeze(0).unsqueeze(0)
         else:
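Context for the hunk above: the caching helper pads every other weight up to 4D before writing it to the weka cache, but the embedding table is consumed on host by torch.nn.Embedding, which expects the raw 2D [vocab_size, dim] tensor, so the new branch saves tok_embeddings.weight unmodified and skips the reshape. A minimal sketch of the loop, assuming the else branch and save step for non-embedding weights that the hunk truncates:

import torch

def cache_weights_sketch(state_dict, file_name="models/experimental/mistral/weights/"):
    # Sketch only: the real cache_weights_in_weka also handles dtype and
    # device setup not shown in the hunk above.
    for key, value in state_dict.items():
        # Embedding table: used on host by torch.nn.Embedding, which
        # expects the raw 2D [vocab_size, dim] tensor, so save it as-is.
        if "tok_embeddings" in key:
            torch.save(value, file_name + str(key) + ".pt")
            continue
        # All other weights are padded to 4D for the device-side cache.
        if len(value.shape) == 1:
            value = value.unsqueeze(0).unsqueeze(0).unsqueeze(0)
        else:
            value = value.unsqueeze(0).unsqueeze(0)  # assumed: pad 2D weights to 4D
        torch.save(value, file_name + str(key) + ".pt")  # assumed save step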
1 change: 0 additions & 1 deletion

@@ -50,7 +50,6 @@ def test_mistral_transformer_inference(pcc, model_location_generator, device, dt

     tt_model = TtTransformer(
         args=model_args,
-        state_dict=state_dict,
         device=device,
         base_address=base_address,
         tt_cache_path=tt_cache_path,
2 changes: 0 additions & 2 deletions models/experimental/mistral/tests/test_perf_mistral.py

@@ -41,7 +41,6 @@ def run_perf_mistral(expected_inference_time, expected_compile_time, device, mod

     mistral_path = model_location_generator("mistral-7B-v0.1", model_subdir="Mistral")
     tokenizer = Tokenizer(str(Path(mistral_path) / "tokenizer.model"))
-    state_dict = torch.load(mistral_path / "consolidated.00.pth")
     base_address = f""
     with open(mistral_path / "params.json", "r") as f:
         model_args = TtModelArgs(**json.loads(f.read()))
@@ -58,7 +57,6 @@ def run_perf_mistral(expected_inference_time, expected_compile_time, device, mod
     tt_cache_path = "/mnt/MLPerf/tt_dnn-models/tt/Mistral/"
     tt_model = TtTransformer(
         args=model_args,
-        state_dict=state_dict,
         device=device,
         base_address=base_address,
         tt_cache_path=tt_cache_path,
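The same state_dict argument is dropped from the callers in gs_demo.py and the transformer test above. A minimal usage sketch of a caller after this commit, assuming model_args, device, and base_address are prepared as in the surrounding test:

# After this commit, the consolidated.00.pth checkpoint is no longer
# loaded just to hand TtTransformer a state_dict; the cache path alone
# is enough to construct the model.
tt_cache_path = "/mnt/MLPerf/tt_dnn-models/tt/Mistral/"
tt_model = TtTransformer(
    args=model_args,
    device=device,
    base_address=base_address,
    tt_cache_path=tt_cache_path,
)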
4 changes: 1 addition & 3 deletions models/experimental/mistral/tt/mistral_transformer.py

@@ -25,7 +25,6 @@ def __init__(
         self,
         args: TtModelArgs,
         device=None,
-        state_dict=None,
         base_address=None,
         tt_cache_path=None,
     ):
@@ -34,11 +33,10 @@
         self.vocab_size = args.vocab_size
         self.n_layers = args.n_layers
         self.device = device
-        self.state_dict = state_dict
         self.base_address = base_address
         assert self.vocab_size > 0

-        embedding_weights = state_dict["tok_embeddings.weight"]
+        embedding_weights = torch.load(tt_cache_path + "tok_embeddings.weight.pt")
         self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim, _weight=embedding_weights)

         self.layers = torch.nn.ModuleList(
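Net effect in the constructor: TtTransformer no longer takes the full checkpoint, and loads only the cached 2D embedding tensor from the weka path. A minimal self-contained sketch of the new pattern; the class and argument names here are simplified stand-ins, and the real constructor also builds the decoder layers:

import torch
import torch.nn as nn

class EmbeddingFromCacheSketch(nn.Module):
    # Hypothetical, simplified stand-in for the TtTransformer constructor
    # after this commit.
    def __init__(self, vocab_size, dim, tt_cache_path):
        super().__init__()
        assert vocab_size > 0
        # Load only the cached [vocab_size, dim] embedding table written
        # by cache_weights_in_weka; no full state_dict is needed.
        embedding_weights = torch.load(tt_cache_path + "tok_embeddings.weight.pt")
        self.tok_embeddings = nn.Embedding(vocab_size, dim, _weight=embedding_weights)

    def forward(self, token_ids):
        # Host-side embedding lookup; the result is then handed to the
        # device-side decoder layers (not shown).
        return self.tok_embeddings(token_ids)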
