From c4a456448dac9bf9f4faf694b93ac780274946e6 Mon Sep 17 00:00:00 2001
From: vigneshkeerthivasanx
Date: Tue, 28 Nov 2023 07:13:26 +0000
Subject: [PATCH] #3824: move mistral embedding weights to weka

---
 models/experimental/mistral/demo/gs_demo.py                   | 1 -
 models/experimental/mistral/mistral_utils.py                  | 3 +++
 models/experimental/mistral/tests/test_mistral_transformer.py | 1 -
 models/experimental/mistral/tests/test_perf_mistral.py        | 2 --
 models/experimental/mistral/tt/mistral_transformer.py         | 4 +---
 5 files changed, 4 insertions(+), 7 deletions(-)

diff --git a/models/experimental/mistral/demo/gs_demo.py b/models/experimental/mistral/demo/gs_demo.py
index da19d3bce31..7bb97c6c26e 100644
--- a/models/experimental/mistral/demo/gs_demo.py
+++ b/models/experimental/mistral/demo/gs_demo.py
@@ -38,7 +38,6 @@ def test_gs_demo_single_input_inference(batch_size, model_location_generator, de
     tt_cache_path = "/mnt/MLPerf/tt_dnn-models/tt/Mistral/"
     tt_model = TtTransformer(
         args=model_args,
-        state_dict=state_dict,
         device=device,
         base_address=base_address,
         tt_cache_path=tt_cache_path,
diff --git a/models/experimental/mistral/mistral_utils.py b/models/experimental/mistral/mistral_utils.py
index c39ac7393fb..a774a09d5bf 100644
--- a/models/experimental/mistral/mistral_utils.py
+++ b/models/experimental/mistral/mistral_utils.py
@@ -92,6 +92,9 @@ def cache_weights_in_weka(model_location_generator, device, dtype, reset_seeds):
     # initial weights are stored in "models/experimental/mistral/weights/" and moved to weka path
     file_name = "models/experimental/mistral/weights/"
     for key, value in state_dict.items():
+        if "tok_embeddings" in key:
+            torch.save(value, file_name + str(key) + ".pt")
+            continue
         if len(value.shape) == 1:
             value = value.unsqueeze(0).unsqueeze(0).unsqueeze(0)
         else:
diff --git a/models/experimental/mistral/tests/test_mistral_transformer.py b/models/experimental/mistral/tests/test_mistral_transformer.py
index b1818ac481f..3922634c4d5 100644
--- a/models/experimental/mistral/tests/test_mistral_transformer.py
+++ b/models/experimental/mistral/tests/test_mistral_transformer.py
@@ -50,7 +50,6 @@ def test_mistral_transformer_inference(pcc, model_location_generator, device, dt
 
     tt_model = TtTransformer(
         args=model_args,
-        state_dict=state_dict,
         device=device,
         base_address=base_address,
         tt_cache_path=tt_cache_path,
diff --git a/models/experimental/mistral/tests/test_perf_mistral.py b/models/experimental/mistral/tests/test_perf_mistral.py
index 111904a2979..5aad280cae7 100644
--- a/models/experimental/mistral/tests/test_perf_mistral.py
+++ b/models/experimental/mistral/tests/test_perf_mistral.py
@@ -41,7 +41,6 @@ def run_perf_mistral(expected_inference_time, expected_compile_time, device, mod
 
     mistral_path = model_location_generator("mistral-7B-v0.1", model_subdir="Mistral")
     tokenizer = Tokenizer(str(Path(mistral_path) / "tokenizer.model"))
-    state_dict = torch.load(mistral_path / "consolidated.00.pth")
     base_address = f""
     with open(mistral_path / "params.json", "r") as f:
         model_args = TtModelArgs(**json.loads(f.read()))
@@ -58,7 +57,6 @@ def run_perf_mistral(expected_inference_time, expected_compile_time, device, mod
     tt_cache_path = "/mnt/MLPerf/tt_dnn-models/tt/Mistral/"
     tt_model = TtTransformer(
         args=model_args,
-        state_dict=state_dict,
         device=device,
         base_address=base_address,
         tt_cache_path=tt_cache_path,
diff --git a/models/experimental/mistral/tt/mistral_transformer.py b/models/experimental/mistral/tt/mistral_transformer.py
index 783442a2d58..7b87ee2307c 100644
--- a/models/experimental/mistral/tt/mistral_transformer.py
+++ b/models/experimental/mistral/tt/mistral_transformer.py
@@ -25,7 +25,6 @@ def __init__(
         self,
         args: TtModelArgs,
         device=None,
-        state_dict=None,
         base_address=None,
         tt_cache_path=None,
     ):
@@ -34,11 +33,10 @@ def __init__(
         self.vocab_size = args.vocab_size
         self.n_layers = args.n_layers
         self.device = device
-        self.state_dict = state_dict
         self.base_address = base_address
         assert self.vocab_size > 0
 
-        embedding_weights = state_dict["tok_embeddings.weight"]
+        embedding_weights = torch.load(tt_cache_path + "tok_embeddings.weight.pt")
         self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim, _weight=embedding_weights)
 
         self.layers = torch.nn.ModuleList(
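A minimal sketch of the two sides of this change, assuming the cache paths shown in the patch; the standalone variable names below are illustrative only and are not part of the patch:

    import torch
    import torch.nn as nn

    # Paths mirror the patch: a local staging dir written by cache_weights_in_weka,
    # and the weka cache path read by TtTransformer.
    weights_dir = "models/experimental/mistral/weights/"
    tt_cache_path = "/mnt/MLPerf/tt_dnn-models/tt/Mistral/"

    # Caching side: the token-embedding tensor is saved unmodified under "<key>.pt",
    # skipping the 4D reshape applied to the other weights, e.g.:
    #   torch.save(state_dict["tok_embeddings.weight"], weights_dir + "tok_embeddings.weight.pt")

    # Loading side: TtTransformer no longer takes a state_dict; it reloads the cached
    # tensor by name and builds the host-side embedding from it.
    embedding_weights = torch.load(tt_cache_path + "tok_embeddings.weight.pt")
    tok_embeddings = nn.Embedding(
        embedding_weights.shape[0],  # vocab_size
        embedding_weights.shape[1],  # model dim
        _weight=embedding_weights,
    )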