From 88dba1a1e9076efbac6a2d4b311944694dc45506 Mon Sep 17 00:00:00 2001 From: Akhmed Rakhmati Date: Tue, 5 Dec 2023 23:09:49 +0000 Subject: [PATCH] #4003: updated ttnn.model_preprocessing to keep the structure of the model weights --- .../reference/torch_functional_bert.py | 53 ++- .../tt/ttnn_functional_bert.py | 53 ++- .../tt/ttnn_optimized_functional_bert.py | 45 +-- .../reference/torch_functional_bloom.py | 119 +++--- .../tt/ttnn_functional_bloom.py | 69 ++-- .../tt/ttnn_optimized_functional_bloom.py | 71 ++-- .../ttnn/integration_tests/bert/test_bert.py | 48 +-- .../bloom/test_bloom_for_causal_lm.py | 35 +- .../test_bloom_for_question_answering.py | 29 +- ttnn/model_preprocessing.py | 376 +++++++++++------- 10 files changed, 451 insertions(+), 447 deletions(-) diff --git a/models/experimental/functional_bert/reference/torch_functional_bert.py b/models/experimental/functional_bert/reference/torch_functional_bert.py index c281fbdada2..5acc6ea8c0b 100644 --- a/models/experimental/functional_bert/reference/torch_functional_bert.py +++ b/models/experimental/functional_bert/reference/torch_functional_bert.py @@ -70,44 +70,43 @@ def torch_bert_encoder( attention_mask, parameters, *, - encoder_index, head_size, ): *_, hidden_size = hidden_states.shape multi_head_attention_output = torch_multi_head_attention( hidden_states, attention_mask, - parameters[f"bert.encoder.layer.{encoder_index}.attention.self.query.weight"].T, - parameters[f"bert.encoder.layer.{encoder_index}.attention.self.query.bias"], - parameters[f"bert.encoder.layer.{encoder_index}.attention.self.key.weight"].T, - parameters[f"bert.encoder.layer.{encoder_index}.attention.self.key.bias"], - parameters[f"bert.encoder.layer.{encoder_index}.attention.self.value.weight"].T, - parameters[f"bert.encoder.layer.{encoder_index}.attention.self.value.bias"], - parameters[f"bert.encoder.layer.{encoder_index}.attention.output.dense.weight"].T, - parameters[f"bert.encoder.layer.{encoder_index}.attention.output.dense.bias"], + parameters.attention.self.query.weight, + parameters.attention.self.query.bias, + parameters.attention.self.key.weight, + parameters.attention.self.key.bias, + parameters.attention.self.value.weight, + parameters.attention.self.value.bias, + parameters.attention.output.dense.weight, + parameters.attention.output.dense.bias, head_size=head_size, ) multi_head_attention_add_and_layer_norm_output = F.layer_norm( hidden_states + multi_head_attention_output, (hidden_size,), - parameters[f"bert.encoder.layer.{encoder_index}.attention.output.LayerNorm.weight"], - parameters[f"bert.encoder.layer.{encoder_index}.attention.output.LayerNorm.bias"], + parameters.attention.output.LayerNorm.weight, + parameters.attention.output.LayerNorm.bias, ) feedforward_output = torch_feedforward( multi_head_attention_add_and_layer_norm_output, - parameters[f"bert.encoder.layer.{encoder_index}.intermediate.dense.weight"].T, - parameters[f"bert.encoder.layer.{encoder_index}.intermediate.dense.bias"], - parameters[f"bert.encoder.layer.{encoder_index}.output.dense.weight"].T, - parameters[f"bert.encoder.layer.{encoder_index}.output.dense.bias"], + parameters.intermediate.dense.weight, + parameters.intermediate.dense.bias, + parameters.output.dense.weight, + parameters.output.dense.bias, ) feedforward_add_and_layer_norm_output = F.layer_norm( multi_head_attention_add_and_layer_norm_output + feedforward_output, (hidden_size,), - parameters[f"bert.encoder.layer.{encoder_index}.output.LayerNorm.weight"], - 
parameters[f"bert.encoder.layer.{encoder_index}.output.LayerNorm.bias"], + parameters.output.LayerNorm.weight, + parameters.output.LayerNorm.bias, ) return feedforward_add_and_layer_norm_output @@ -119,28 +118,26 @@ def torch_bert( attention_mask, parameters, *, - num_encoders, head_size, ): - word_embeddings = F.embedding(input_ids, parameters["bert.embeddings.word_embeddings.weight"]) - token_type_embeddings = F.embedding(token_type_ids, parameters["bert.embeddings.token_type_embeddings.weight"]) + word_embeddings = F.embedding(input_ids, parameters.bert.embeddings.word_embeddings.weight) + token_type_embeddings = F.embedding(token_type_ids, parameters.bert.embeddings.token_type_embeddings.weight) encoder_input = word_embeddings + token_type_embeddings *_, hidden_size = encoder_input.shape encoder_input = F.layer_norm( encoder_input, (hidden_size,), - parameters["bert.embeddings.LayerNorm.weight"], - parameters["bert.embeddings.LayerNorm.bias"], + parameters.bert.embeddings.LayerNorm.weight, + parameters.bert.embeddings.LayerNorm.bias, ) encoder_output = None - for encoder_index in range(num_encoders): + for encoder_parameters in parameters.bert.encoder.layer: encoder_output = torch_bert_encoder( encoder_input, attention_mask, - parameters, - encoder_index=encoder_index, + encoder_parameters, head_size=head_size, ) encoder_input = encoder_output @@ -153,7 +150,6 @@ def torch_bert_for_question_answering( attention_mask, parameters, *, - num_encoders, head_size, ): bert_output = torch_bert( @@ -161,11 +157,10 @@ def torch_bert_for_question_answering( token_type_ids, attention_mask, parameters, - num_encoders=num_encoders, head_size=head_size, ) qa_outputs = bert_output - qa_outputs = qa_outputs @ parameters["qa_outputs.weight"].T - qa_outputs = qa_outputs + parameters["qa_outputs.bias"] + qa_outputs = qa_outputs @ parameters.qa_outputs.weight + qa_outputs = qa_outputs + parameters.qa_outputs.bias return qa_outputs diff --git a/models/experimental/functional_bert/tt/ttnn_functional_bert.py b/models/experimental/functional_bert/tt/ttnn_functional_bert.py index a11cb508aaf..24dac92e95d 100644 --- a/models/experimental/functional_bert/tt/ttnn_functional_bert.py +++ b/models/experimental/functional_bert/tt/ttnn_functional_bert.py @@ -75,41 +75,40 @@ def ttnn_bert_encoder( attention_mask, parameters, *, - encoder_index, head_size, ): multi_head_attention_output = ttnn_multi_head_attention( hidden_states, attention_mask, - parameters[f"bert.encoder.layer.{encoder_index}.attention.self.query.weight"], - parameters[f"bert.encoder.layer.{encoder_index}.attention.self.query.bias"], - parameters[f"bert.encoder.layer.{encoder_index}.attention.self.key.weight"], - parameters[f"bert.encoder.layer.{encoder_index}.attention.self.key.bias"], - parameters[f"bert.encoder.layer.{encoder_index}.attention.self.value.weight"], - parameters[f"bert.encoder.layer.{encoder_index}.attention.self.value.bias"], - parameters[f"bert.encoder.layer.{encoder_index}.attention.output.dense.weight"], - parameters[f"bert.encoder.layer.{encoder_index}.attention.output.dense.bias"], + parameters.attention.self.query.weight, + parameters.attention.self.query.bias, + parameters.attention.self.key.weight, + parameters.attention.self.key.bias, + parameters.attention.self.value.weight, + parameters.attention.self.value.bias, + parameters.attention.output.dense.weight, + parameters.attention.output.dense.bias, head_size=head_size, ) hidden_states = ttnn.experimental.layer_norm( hidden_states + multi_head_attention_output, - 
weight=parameters[f"bert.encoder.layer.{encoder_index}.attention.output.LayerNorm.weight"], - bias=parameters[f"bert.encoder.layer.{encoder_index}.attention.output.LayerNorm.bias"], + weight=parameters.attention.output.LayerNorm.weight, + bias=parameters.attention.output.LayerNorm.bias, ) feedforward_output = ttnn_feedforward( hidden_states, - parameters[f"bert.encoder.layer.{encoder_index}.intermediate.dense.weight"], - parameters[f"bert.encoder.layer.{encoder_index}.intermediate.dense.bias"], - parameters[f"bert.encoder.layer.{encoder_index}.output.dense.weight"], - parameters[f"bert.encoder.layer.{encoder_index}.output.dense.bias"], + parameters.intermediate.dense.weight, + parameters.intermediate.dense.bias, + parameters.output.dense.weight, + parameters.output.dense.bias, ) hidden_states = ttnn.experimental.layer_norm( hidden_states + feedforward_output, - weight=parameters[f"bert.encoder.layer.{encoder_index}.output.LayerNorm.weight"], - bias=parameters[f"bert.encoder.layer.{encoder_index}.output.LayerNorm.bias"], + weight=parameters.output.LayerNorm.weight, + bias=parameters.output.LayerNorm.bias, ) return hidden_states @@ -121,30 +120,28 @@ def ttnn_bert( attention_mask, parameters, *, - num_encoders, head_size, ): word_embeddings = ttnn.embedding( - input_ids, parameters["bert.embeddings.word_embeddings.weight"], layout=ttnn.TILE_LAYOUT + input_ids, parameters.bert.embeddings.word_embeddings.weight, layout=ttnn.TILE_LAYOUT ) token_type_embeddings = ttnn.embedding( - token_type_ids, parameters["bert.embeddings.token_type_embeddings.weight"], layout=ttnn.TILE_LAYOUT + token_type_ids, parameters.bert.embeddings.token_type_embeddings.weight, layout=ttnn.TILE_LAYOUT ) encoder_input = word_embeddings + token_type_embeddings encoder_input = ttnn.experimental.layer_norm( encoder_input, - weight=parameters[f"bert.embeddings.LayerNorm.weight"], - bias=parameters[f"bert.embeddings.LayerNorm.bias"], + weight=parameters.bert.embeddings.LayerNorm.weight, + bias=parameters.bert.embeddings.LayerNorm.bias, ) encoder_output = None - for encoder_index in range(num_encoders): + for encoder_parameters in parameters.bert.encoder.layer: encoder_output = ttnn_bert_encoder( encoder_input, attention_mask, - parameters, - encoder_index=encoder_index, + encoder_parameters, head_size=head_size, ) encoder_input = encoder_output @@ -157,7 +154,6 @@ def ttnn_bert_for_question_answering( attention_mask, parameters, *, - num_encoders, head_size, ): bert_output = ttnn_bert( @@ -165,12 +161,11 @@ def ttnn_bert_for_question_answering( token_type_ids, attention_mask, parameters, - num_encoders=num_encoders, head_size=head_size, ) qa_outputs = bert_output - qa_outputs = qa_outputs @ parameters["qa_outputs.weight"] - qa_outputs = qa_outputs + parameters["qa_outputs.bias"] + qa_outputs = qa_outputs @ parameters.qa_outputs.weight + qa_outputs = qa_outputs + parameters.qa_outputs.bias return qa_outputs diff --git a/models/experimental/functional_bert/tt/ttnn_optimized_functional_bert.py b/models/experimental/functional_bert/tt/ttnn_optimized_functional_bert.py index 906236a6145..ce3a4cbffd5 100644 --- a/models/experimental/functional_bert/tt/ttnn_optimized_functional_bert.py +++ b/models/experimental/functional_bert/tt/ttnn_optimized_functional_bert.py @@ -110,24 +110,23 @@ def ttnn_optimized_bert_encoder( attention_mask, parameters, *, - encoder_index, head_size, ): multi_head_attention_output = ttnn_optimized_multi_head_attention( hidden_states, attention_mask, - 
parameters[f"bert.encoder.layer.{encoder_index}.attention.self.fused_qkv.weight"], - parameters[f"bert.encoder.layer.{encoder_index}.attention.self.fused_qkv.bias"], - parameters[f"bert.encoder.layer.{encoder_index}.attention.output.dense.weight"], - parameters[f"bert.encoder.layer.{encoder_index}.attention.output.dense.bias"], + parameters.attention.self.fused_qkv.weight, + parameters.attention.self.fused_qkv.bias, + parameters.attention.output.dense.weight, + parameters.attention.output.dense.bias, head_size=head_size, ) multi_head_attention_add_and_layer_norm_output = ttnn.experimental.layer_norm( hidden_states, residual_input=multi_head_attention_output, - weight=parameters[f"bert.encoder.layer.{encoder_index}.attention.output.LayerNorm.weight"], - bias=parameters[f"bert.encoder.layer.{encoder_index}.attention.output.LayerNorm.bias"], + weight=parameters.attention.output.LayerNorm.weight, + bias=parameters.attention.output.LayerNorm.bias, memory_config=ttnn.L1_MEMORY_CONFIG, ) ttnn.deallocate(hidden_states) @@ -135,17 +134,17 @@ def ttnn_optimized_bert_encoder( feedforward_output = ttnn_optimized_feedforward( multi_head_attention_add_and_layer_norm_output, - parameters[f"bert.encoder.layer.{encoder_index}.intermediate.dense.weight"], - parameters[f"bert.encoder.layer.{encoder_index}.intermediate.dense.bias"], - parameters[f"bert.encoder.layer.{encoder_index}.output.dense.weight"], - parameters[f"bert.encoder.layer.{encoder_index}.output.dense.bias"], + parameters.intermediate.dense.weight, + parameters.intermediate.dense.bias, + parameters.output.dense.weight, + parameters.output.dense.bias, ) feedforward_add_and_layer_norm_output = ttnn.experimental.layer_norm( multi_head_attention_add_and_layer_norm_output, residual_input=feedforward_output, - weight=parameters[f"bert.encoder.layer.{encoder_index}.output.LayerNorm.weight"], - bias=parameters[f"bert.encoder.layer.{encoder_index}.output.LayerNorm.bias"], + weight=parameters.output.LayerNorm.weight, + bias=parameters.output.LayerNorm.bias, memory_config=ttnn.L1_MEMORY_CONFIG, ) ttnn.deallocate(multi_head_attention_add_and_layer_norm_output) @@ -160,21 +159,20 @@ def ttnn_optimized_bert( attention_mask, parameters, *, - num_encoders, head_size, ): import tt_lib as ttl word_embeddings = ttnn.embedding( input_ids, - parameters["bert.embeddings.word_embeddings.weight"], + parameters.bert.embeddings.word_embeddings.weight, layout=ttnn.TILE_LAYOUT, ) ttnn.deallocate(input_ids) token_type_embeddings = ttnn.embedding( token_type_ids, - parameters["bert.embeddings.token_type_embeddings.weight"], + parameters.bert.embeddings.token_type_embeddings.weight, layout=ttnn.TILE_LAYOUT, ) ttnn.deallocate(token_type_ids) @@ -185,19 +183,18 @@ def ttnn_optimized_bert( encoder_input = ttnn.experimental.layer_norm( embeddings, - weight=parameters[f"bert.embeddings.LayerNorm.weight"], - bias=parameters[f"bert.embeddings.LayerNorm.bias"], + weight=parameters.bert.embeddings.LayerNorm.weight, + bias=parameters.bert.embeddings.LayerNorm.bias, memory_config=ttnn.L1_MEMORY_CONFIG, ) ttnn.deallocate(embeddings) encoder_output = None - for encoder_index in range(num_encoders): + for encoder_parameters in parameters.bert.encoder.layer: encoder_output = ttnn_optimized_bert_encoder( encoder_input, attention_mask, - parameters, - encoder_index=encoder_index, + encoder_parameters, head_size=head_size, ) encoder_output = ttnn.reallocate(encoder_output) @@ -212,7 +209,6 @@ def ttnn_optimized_bert_for_question_answering( attention_mask, parameters, *, - num_encoders, 
head_size, ): bert_output = ttnn_optimized_bert( @@ -220,14 +216,13 @@ def ttnn_optimized_bert_for_question_answering( token_type_ids, attention_mask, parameters, - num_encoders=num_encoders, head_size=head_size, ) qa_outputs = ttnn.linear( bert_output, - parameters["qa_outputs.weight"], - bias=parameters["qa_outputs.bias"], + parameters.qa_outputs.weight, + bias=parameters.qa_outputs.bias, memory_config=ttnn.L1_MEMORY_CONFIG, ) diff --git a/models/experimental/functional_bloom/reference/torch_functional_bloom.py b/models/experimental/functional_bloom/reference/torch_functional_bloom.py index bdddb1650cd..2ce33660efb 100644 --- a/models/experimental/functional_bloom/reference/torch_functional_bloom.py +++ b/models/experimental/functional_bloom/reference/torch_functional_bloom.py @@ -10,16 +10,7 @@ import torch.utils.checkpoint from torch.nn import functional as F - - -def transpose(tensor): - ndim = len(tensor.shape) - if ndim < 2: - return tensor - else: - dims = tuple(range(ndim - 2)) + (ndim - 1, ndim - 2) - new_tensor = torch.permute(tensor, dims=dims) - return new_tensor +import transformers # From transformers/models/bloom/modeling_bloom.py @@ -196,34 +187,34 @@ def mlp( return hidden_states -def bloom(input_ids, alibi, causal_mask, parameters, num_heads, hidden_layers): - inputs_embeds = F.embedding(input_ids, parameters["transformer.word_embeddings.weight"]) +def bloom(input_ids, alibi, causal_mask, parameters, num_heads): + inputs_embeds = F.embedding(input_ids, parameters.transformer.word_embeddings.weight) hidden_size = inputs_embeds.shape[2] head_size = hidden_size // num_heads hidden_states = F.layer_norm( inputs_embeds, (hidden_size,), - parameters[f"transformer.word_embeddings_layernorm.weight"], - parameters[f"transformer.word_embeddings_layernorm.bias"], + parameters.transformer.word_embeddings_layernorm.weight, + parameters.transformer.word_embeddings_layernorm.bias, ) - for i in range(0, hidden_layers): + for layer_parameters in parameters.transformer.h: normalized_hidden_states = F.layer_norm( hidden_states, (hidden_size,), - parameters[f"transformer.h.{i}.input_layernorm.weight"], - parameters[f"transformer.h.{i}.input_layernorm.bias"], + layer_parameters.input_layernorm.weight, + layer_parameters.input_layernorm.bias, ) attention_output = multi_head_attention( normalized_hidden_states, alibi, causal_mask, - transpose(parameters[f"transformer.h.{i}.self_attention.query_key_value.weight"]), - parameters[f"transformer.h.{i}.self_attention.query_key_value.bias"], - transpose(parameters[f"transformer.h.{i}.self_attention.dense.weight"]), - parameters[f"transformer.h.{i}.self_attention.dense.bias"], + layer_parameters.self_attention.query_key_value.weight, + layer_parameters.self_attention.query_key_value.bias, + layer_parameters.self_attention.dense.weight, + layer_parameters.self_attention.dense.bias, head_size=head_size, ) attention_output += hidden_states @@ -231,16 +222,16 @@ def bloom(input_ids, alibi, causal_mask, parameters, num_heads, hidden_layers): normalized_attention_output = F.layer_norm( attention_output, (hidden_size,), - parameters[f"transformer.h.{i}.post_attention_layernorm.weight"], - parameters[f"transformer.h.{i}.post_attention_layernorm.bias"], + layer_parameters.post_attention_layernorm.weight, + layer_parameters.post_attention_layernorm.bias, ) mlp_output = mlp( normalized_attention_output, - transpose(parameters[f"transformer.h.{i}.mlp.dense_h_to_4h.weight"]), - parameters[f"transformer.h.{i}.mlp.dense_h_to_4h.bias"], - 
transpose(parameters[f"transformer.h.{i}.mlp.dense_4h_to_h.weight"]), - parameters[f"transformer.h.{i}.mlp.dense_4h_to_h.bias"], + layer_parameters.mlp.dense_h_to_4h.weight, + layer_parameters.mlp.dense_h_to_4h.bias, + layer_parameters.mlp.dense_4h_to_h.weight, + layer_parameters.mlp.dense_4h_to_h.bias, ) mlp_output += attention_output hidden_states = mlp_output @@ -248,15 +239,15 @@ def bloom(input_ids, alibi, causal_mask, parameters, num_heads, hidden_layers): hidden_states = F.layer_norm( hidden_states, (hidden_size,), - parameters[f"transformer.ln_f.weight"], - parameters[f"transformer.ln_f.bias"], + parameters.transformer.ln_f.weight, + parameters.transformer.ln_f.bias, ) return hidden_states -def bloom_for_causal_lm(input_ids, alibi, causal_mask, parameters, num_heads, hidden_layers): +def bloom_for_causal_lm(input_ids, alibi, causal_mask, parameters, num_heads): start = time.time() - hidden_states = bloom(input_ids, alibi, causal_mask, parameters, num_heads, hidden_layers) + hidden_states = bloom(input_ids, alibi, causal_mask, parameters, num_heads) end = time.time() batch_size, _ = input_ids.shape duration = end - start @@ -264,7 +255,7 @@ def bloom_for_causal_lm(input_ids, alibi, causal_mask, parameters, num_heads, hi logger.info(f"Samples per second: {1 / duration * batch_size}") # return logits - return hidden_states @ transpose(parameters[f"lm_head.weight"]) + return hidden_states @ parameters.lm_head.weight def preprocess_inputs( @@ -295,36 +286,40 @@ def preprocess_inputs( return padded_input_ids, alibi, causal_mask -def preprocess_parameters(parameters, num_heads): - preprocessed_parameters = {} - for name, parameter in parameters.items(): +def custom_preprocessor(torch_model, name): + parameters = {} + if isinstance(torch_model, transformers.models.bloom.modeling_bloom.BloomAttention): + weight = torch_model.query_key_value.weight + bias = torch_model.query_key_value.bias + + assert weight.shape[-1] == 1024 + num_heads = 16 + + three_times_hidden_size, _ = weight.shape + hidden_size = three_times_hidden_size // 3 + head_size = hidden_size // num_heads + # Store QKV one after another instead of interleaving heads - if "query_key_value.weight" in name: - three_times_hidden_size, _ = parameter.shape - hidden_size = three_times_hidden_size // 3 - head_size = hidden_size // num_heads - - parameter = parameter.view(num_heads, 3, head_size, hidden_size) - query, key, value = parameter[:, 0], parameter[:, 1], parameter[:, 2] - query = torch.reshape(query, (hidden_size, hidden_size)) - key = torch.reshape(key, (hidden_size, hidden_size)) - value = torch.reshape(value, (hidden_size, hidden_size)) - preprocessed_parameter = torch.cat([query, key, value], dim=0) - preprocessed_parameters[name] = preprocessed_parameter + weight = weight.view(num_heads, 3, head_size, hidden_size) + query, key, value = weight[:, 0], weight[:, 1], weight[:, 2] + query = torch.reshape(query, (hidden_size, hidden_size)) + key = torch.reshape(key, (hidden_size, hidden_size)) + value = torch.reshape(value, (hidden_size, hidden_size)) + preprocessed_weight = torch.cat([query, key, value], dim=0) # Store QKV one after another instead of interleaving heads - elif "query_key_value.bias" in name: - (three_times_hidden_size,) = parameter.shape - hidden_size = three_times_hidden_size // 3 - head_size = hidden_size // num_heads - - parameter = parameter.view(num_heads, 3, head_size) - query, key, value = parameter[:, 0], parameter[:, 1], parameter[:, 2] - query = torch.reshape(query, (hidden_size,)) - key = 
torch.reshape(key, (hidden_size,)) - value = torch.reshape(value, (hidden_size,)) - preprocessed_parameter = torch.cat([query, key, value], dim=0) - preprocessed_parameters[name] = preprocessed_parameter - else: - preprocessed_parameters[name] = parameter - return preprocessed_parameters + bias = bias.view(num_heads, 3, head_size) + query, key, value = bias[:, 0], bias[:, 1], bias[:, 2] + query = torch.reshape(query, (hidden_size,)) + key = torch.reshape(key, (hidden_size,)) + value = torch.reshape(value, (hidden_size,)) + preprocessed_bias = torch.cat([query, key, value], dim=0) + + parameters = {"query_key_value": {}, "dense": {}} + + parameters["query_key_value"]["weight"] = preprocessed_weight.T.contiguous() + parameters["query_key_value"]["bias"] = preprocessed_bias + + parameters["dense"]["weight"] = torch_model.dense.weight.T.contiguous() + parameters["dense"]["bias"] = torch_model.dense.bias + return parameters diff --git a/models/experimental/functional_bloom/tt/ttnn_functional_bloom.py b/models/experimental/functional_bloom/tt/ttnn_functional_bloom.py index 1accf79f0da..20a9a6deb70 100644 --- a/models/experimental/functional_bloom/tt/ttnn_functional_bloom.py +++ b/models/experimental/functional_bloom/tt/ttnn_functional_bloom.py @@ -198,11 +198,10 @@ def bloom( causal_mask, parameters, num_heads, - hidden_layers, ): inputs_embeds = ttnn.embedding( input_ids, - parameters["transformer.word_embeddings.weight"], + parameters.transformer.word_embeddings.weight, layout=ttnn.TILE_LAYOUT, ) hidden_size = inputs_embeds.shape[-1] @@ -210,41 +209,41 @@ def bloom( hidden_states = ttnn.experimental.layer_norm( inputs_embeds, - weight=parameters[f"transformer.word_embeddings_layernorm.weight"], - bias=parameters[f"transformer.word_embeddings_layernorm.bias"], + weight=parameters.transformer.word_embeddings_layernorm.weight, + bias=parameters.transformer.word_embeddings_layernorm.bias, ) - for i in range(0, hidden_layers): + for layer_parameters in parameters.transformer.h: normalized_hidden_states = ttnn.experimental.layer_norm( hidden_states, - weight=parameters[f"transformer.h.{i}.input_layernorm.weight"], - bias=parameters[f"transformer.h.{i}.input_layernorm.bias"], + weight=layer_parameters.input_layernorm.weight, + bias=layer_parameters.input_layernorm.bias, ) attention_output = multi_head_attention( normalized_hidden_states, alibi, causal_mask, - parameters[f"transformer.h.{i}.self_attention.query_key_value.weight"], - parameters[f"transformer.h.{i}.self_attention.query_key_value.bias"], - parameters[f"transformer.h.{i}.self_attention.dense.weight"], - parameters[f"transformer.h.{i}.self_attention.dense.bias"], + layer_parameters.self_attention.query_key_value.weight, + layer_parameters.self_attention.query_key_value.bias, + layer_parameters.self_attention.dense.weight, + layer_parameters.self_attention.dense.bias, head_size=head_size, ) attention_output = attention_output + hidden_states normalized_attention_output = ttnn.experimental.layer_norm( attention_output, - weight=parameters[f"transformer.h.{i}.post_attention_layernorm.weight"], - bias=parameters[f"transformer.h.{i}.post_attention_layernorm.bias"], + weight=layer_parameters.post_attention_layernorm.weight, + bias=layer_parameters.post_attention_layernorm.bias, ) mlp_output = mlp( normalized_attention_output, - parameters[f"transformer.h.{i}.mlp.dense_h_to_4h.weight"], - parameters[f"transformer.h.{i}.mlp.dense_h_to_4h.bias"], - parameters[f"transformer.h.{i}.mlp.dense_4h_to_h.weight"], - 
parameters[f"transformer.h.{i}.mlp.dense_4h_to_h.bias"], + layer_parameters.mlp.dense_h_to_4h.weight, + layer_parameters.mlp.dense_h_to_4h.bias, + layer_parameters.mlp.dense_4h_to_h.weight, + layer_parameters.mlp.dense_4h_to_h.bias, ) mlp_output = mlp_output + attention_output @@ -252,27 +251,27 @@ def bloom( hidden_states = ttnn.experimental.layer_norm( hidden_states, - weight=parameters[f"transformer.ln_f.weight"], - bias=parameters[f"transformer.ln_f.bias"], + weight=parameters.transformer.ln_f.weight, + bias=parameters.transformer.ln_f.bias, ) return hidden_states -def bloom_for_causal_lm(input_ids, alibi, causal_mask, parameters, num_heads, hidden_layers): - hidden_states = bloom(input_ids, alibi, causal_mask, parameters, num_heads, hidden_layers) +def bloom_for_causal_lm(input_ids, alibi, causal_mask, parameters, num_heads): + hidden_states = bloom(input_ids, alibi, causal_mask, parameters, num_heads) # Unfortuntely we do not have the ability to handle large tensors yet. So running final matmul ising torch is a workaround. hidden_states = ttnn.from_device(hidden_states) hidden_states = ttnn.to_layout(hidden_states, ttnn.ROW_MAJOR_LAYOUT) hidden_states = ttnn.to_torch(hidden_states).to(torch.float32) - output = hidden_states @ parameters[f"lm_head.weight"] + output = hidden_states @ parameters.lm_head.weight return output -def bloom_for_question_answering(input_ids, alibi, causal_mask, parameters, num_heads, hidden_layers): - hidden_states = bloom(input_ids, alibi, causal_mask, parameters, num_heads, hidden_layers) - hidden_states = ttnn.linear(hidden_states, parameters[f"qa_outputs.weight"], bias=parameters[f"qa_outputs.bias"]) +def bloom_for_question_answering(input_ids, alibi, causal_mask, parameters, num_heads): + hidden_states = bloom(input_ids, alibi, causal_mask, parameters, num_heads) + hidden_states = ttnn.linear(hidden_states, parameters.qa_outputs.weight, bias=parameters.qa_outputs.bias) return hidden_states @@ -315,7 +314,7 @@ def preprocess_inputs( return padded_input_ids, alibi, causal_mask -def custom_preprocessor(parameters_config, torch_model, full_name, **kwargs): +def custom_preprocessor(torch_model, name): parameters = {} if isinstance(torch_model, transformers.models.bloom.modeling_bloom.BloomAttention): weight = torch_model.query_key_value.weight @@ -344,17 +343,11 @@ def custom_preprocessor(parameters_config, torch_model, full_name, **kwargs): value = torch.reshape(value, (hidden_size,)) preprocessed_bias = torch.cat([query, key, value], dim=0) - parameters[f"{full_name}query_key_value.weight"] = preprocess_linear_weight( - preprocessed_weight, dtype=parameters_config.linear_weight_dtype - ) - parameters[f"{full_name}query_key_value.bias"] = preprocess_linear_bias( - preprocessed_bias, dtype=parameters_config.linear_bias_dtype - ) + parameters = {"query_key_value": {}, "dense": {}} - parameters[f"{full_name}dense.weight"] = preprocess_linear_weight( - torch_model.dense.weight, dtype=parameters_config.linear_weight_dtype - ) - parameters[f"{full_name}dense.bias"] = preprocess_linear_bias( - torch_model.dense.bias, dtype=parameters_config.linear_bias_dtype - ) + parameters["query_key_value"]["weight"] = preprocess_linear_weight(preprocessed_weight, dtype=ttnn.bfloat16) + parameters["query_key_value"]["bias"] = preprocess_linear_bias(preprocessed_bias, dtype=ttnn.bfloat16) + + parameters["dense"]["weight"] = preprocess_linear_weight(torch_model.dense.weight, dtype=ttnn.bfloat16) + parameters["dense"]["bias"] = preprocess_linear_bias(torch_model.dense.bias, 
dtype=ttnn.bfloat16) return parameters diff --git a/models/experimental/functional_bloom/tt/ttnn_optimized_functional_bloom.py b/models/experimental/functional_bloom/tt/ttnn_optimized_functional_bloom.py index 04fc631e4a9..92a5b417cfe 100644 --- a/models/experimental/functional_bloom/tt/ttnn_optimized_functional_bloom.py +++ b/models/experimental/functional_bloom/tt/ttnn_optimized_functional_bloom.py @@ -208,11 +208,10 @@ def bloom( causal_mask, parameters, num_heads, - hidden_layers, ): inputs_embeds = ttnn.embedding( input_ids, - parameters["transformer.word_embeddings.weight"], + parameters.transformer.word_embeddings.weight, layout=ttnn.TILE_LAYOUT, ) @@ -221,19 +220,19 @@ def bloom( hidden_states = ttnn.experimental.layer_norm( inputs_embeds, - weight=parameters[f"transformer.word_embeddings_layernorm.weight"], - bias=parameters[f"transformer.word_embeddings_layernorm.bias"], + weight=parameters.transformer.word_embeddings_layernorm.weight, + bias=parameters.transformer.word_embeddings_layernorm.bias, memory_config=BLOOM_MEMORY_CONFIG, ) ttnn.deallocate(inputs_embeds) if BLOOM_MEMORY_CONFIG == ttnn.L1_MEMORY_CONFIG: hidden_states = ttnn.reallocate(hidden_states) - for i in range(0, hidden_layers): + for layer_parameters in parameters.transformer.h: normalized_hidden_states = ttnn.experimental.layer_norm( hidden_states, - weight=parameters[f"transformer.h.{i}.input_layernorm.weight"], - bias=parameters[f"transformer.h.{i}.input_layernorm.bias"], + weight=layer_parameters.input_layernorm.weight, + bias=layer_parameters.input_layernorm.bias, memory_config=BLOOM_MEMORY_CONFIG, ) @@ -241,10 +240,10 @@ def bloom( normalized_hidden_states, alibi, causal_mask, - parameters[f"transformer.h.{i}.self_attention.query_key_value.weight"], - parameters[f"transformer.h.{i}.self_attention.query_key_value.bias"], - parameters[f"transformer.h.{i}.self_attention.dense.weight"], - parameters[f"transformer.h.{i}.self_attention.dense.bias"], + layer_parameters.self_attention.query_key_value.weight, + layer_parameters.self_attention.query_key_value.bias, + layer_parameters.self_attention.dense.weight, + layer_parameters.self_attention.dense.bias, head_size=head_size, ) ttnn.deallocate(normalized_hidden_states) @@ -254,17 +253,17 @@ def bloom( normalized_attention_output = ttnn.experimental.layer_norm( attention_output, - weight=parameters[f"transformer.h.{i}.post_attention_layernorm.weight"], - bias=parameters[f"transformer.h.{i}.post_attention_layernorm.bias"], + weight=layer_parameters.post_attention_layernorm.weight, + bias=layer_parameters.post_attention_layernorm.bias, memory_config=BLOOM_MEMORY_CONFIG, ) mlp_output = mlp( normalized_attention_output, - parameters[f"transformer.h.{i}.mlp.dense_h_to_4h.weight"], - parameters[f"transformer.h.{i}.mlp.dense_h_to_4h.bias"], - parameters[f"transformer.h.{i}.mlp.dense_4h_to_h.weight"], - parameters[f"transformer.h.{i}.mlp.dense_4h_to_h.bias"], + layer_parameters.mlp.dense_h_to_4h.weight, + layer_parameters.mlp.dense_h_to_4h.bias, + layer_parameters.mlp.dense_4h_to_h.weight, + layer_parameters.mlp.dense_4h_to_h.bias, ) ttnn.deallocate(normalized_attention_output) @@ -278,30 +277,30 @@ def bloom( hidden_states = ttnn.experimental.layer_norm( hidden_states, - weight=parameters[f"transformer.ln_f.weight"], - bias=parameters[f"transformer.ln_f.bias"], + weight=parameters.transformer.ln_f.weight, + bias=parameters.transformer.ln_f.bias, ) return hidden_states -def bloom_for_causal_lm(input_ids, alibi, casual_mask, parameters, num_heads, hidden_layers): - 
hidden_states = bloom(input_ids, alibi, casual_mask, parameters, num_heads, hidden_layers) +def bloom_for_causal_lm(input_ids, alibi, casual_mask, parameters, num_heads): + hidden_states = bloom(input_ids, alibi, casual_mask, parameters, num_heads) # Unfortuntely we do not have the ability to handle large tensors yet. So running final matmul ising torch is a workaround. hidden_states = ttnn.from_device(hidden_states) hidden_states = ttnn.to_layout(hidden_states, ttnn.ROW_MAJOR_LAYOUT) hidden_states = ttnn.to_torch(hidden_states).to(torch.float32) - output = hidden_states @ parameters[f"lm_head.weight"] + output = hidden_states @ parameters.lm_head.weight return output -def bloom_for_question_answering(input_ids, alibi, casual_mask, parameters, num_heads, hidden_layers): - hidden_states = bloom(input_ids, alibi, casual_mask, parameters, num_heads, hidden_layers) +def bloom_for_question_answering(input_ids, alibi, casual_mask, parameters, num_heads): + hidden_states = bloom(input_ids, alibi, casual_mask, parameters, num_heads) hidden_states = ttnn.linear( hidden_states, - parameters[f"qa_outputs.weight"], - bias=parameters[f"qa_outputs.bias"], + parameters.qa_outputs.weight, + bias=parameters.qa_outputs.bias, memory_config=BLOOM_MEMORY_CONFIG, ) return hidden_states @@ -346,7 +345,7 @@ def preprocess_inputs( return padded_input_ids, alibi, causal_mask -def custom_preprocessor(parameters_config, torch_model, full_name, **kwargs): +def custom_preprocessor(torch_model, name): parameters = {} if isinstance(torch_model, transformers.models.bloom.modeling_bloom.BloomAttention): weight = torch_model.query_key_value.weight @@ -375,17 +374,11 @@ def custom_preprocessor(parameters_config, torch_model, full_name, **kwargs): value = torch.reshape(value, (hidden_size,)) preprocessed_bias = torch.cat([query, key, value], dim=0) - parameters[f"{full_name}query_key_value.weight"] = preprocess_linear_weight( - preprocessed_weight, dtype=parameters_config.linear_weight_dtype - ) - parameters[f"{full_name}query_key_value.bias"] = preprocess_linear_bias( - preprocessed_bias, dtype=parameters_config.linear_bias_dtype - ) + parameters = {"query_key_value": {}, "dense": {}} - parameters[f"{full_name}dense.weight"] = preprocess_linear_weight( - torch_model.dense.weight, dtype=parameters_config.linear_weight_dtype - ) - parameters[f"{full_name}dense.bias"] = preprocess_linear_bias( - torch_model.dense.bias, dtype=parameters_config.linear_bias_dtype - ) + parameters["query_key_value"]["weight"] = preprocess_linear_weight(preprocessed_weight, dtype=ttnn.bfloat16) + parameters["query_key_value"]["bias"] = preprocess_linear_bias(preprocessed_bias, dtype=ttnn.bfloat16) + + parameters["dense"]["weight"] = preprocess_linear_weight(torch_model.dense.weight, dtype=ttnn.bfloat16) + parameters["dense"]["bias"] = preprocess_linear_bias(torch_model.dense.bias, dtype=ttnn.bfloat16) return parameters diff --git a/tests/ttnn/integration_tests/bert/test_bert.py b/tests/ttnn/integration_tests/bert/test_bert.py index 2f0c86bf818..2435f5a9cb4 100644 --- a/tests/ttnn/integration_tests/bert/test_bert.py +++ b/tests/ttnn/integration_tests/bert/test_bert.py @@ -25,7 +25,6 @@ torch_bert_for_question_answering, ) from ttnn.model_preprocessing import ( - ParametersConfig, preprocess_model_parameters, preprocess_linear_bias, preprocess_linear_weight, @@ -61,11 +60,11 @@ def ttnn_bert_preprocess_inputs( return input_ids, token_type_ids, attention_mask -def is_to_be_converted(torch_model, full_name): +def convert_to_ttnn(torch_model, full_name): 
return True -def custom_preprocessor(parameters_config, torch_model, full_name, **kwargs): +def custom_preprocessor(torch_model, name): parameters = {} if isinstance(torch_model, transformers.models.bert.modeling_bert.BertSelfAttention): qkv_weight = torch.cat( @@ -80,51 +79,49 @@ def custom_preprocessor(parameters_config, torch_model, full_name, **kwargs): [torch_model.query.bias, torch_model.key.bias, torch_model.value.bias], dim=0, ) - parameters[f"{full_name}fused_qkv.weight"] = preprocess_linear_weight( - qkv_weight, dtype=parameters_config.linear_weight_dtype - ) - parameters[f"{full_name}fused_qkv.bias"] = preprocess_linear_bias( - qkv_bias, dtype=parameters_config.linear_bias_dtype - ) + + parameters = {"fused_qkv": {}} + parameters["fused_qkv"]["weight"] = preprocess_linear_weight(qkv_weight, dtype=ttnn.bfloat16) + parameters["fused_qkv"]["bias"] = preprocess_linear_bias(qkv_bias, dtype=ttnn.bfloat16) return parameters def run_bert_question_and_answering_inference(model_name, batch_size, sequence_size, use_optimized_version, device): torch.manual_seed(1234) - torch_model = transformers.BertForQuestionAnswering.from_pretrained(model_name, torchscript=False).eval() - config = torch_model.config + config = transformers.BertConfig.from_pretrained(model_name) - num_encoders = config.num_hidden_layers head_size = config.hidden_size // config.num_attention_heads # TODO(arakhmati): re-enable the line below once the issue with ttnn.embedding is fixed - # torch_bert_input = torch.randint(0, torch_model.config.vocab_size, (batch_size, sequence_size)).to(torch.int32) + # torch_bert_input = torch.randint(0, config.config.vocab_size, (batch_size, sequence_size)).to(torch.int32) torch_bert_input = torch.randint(0, 1, (batch_size, sequence_size)).to(torch.int32) torch_token_type_ids = torch.zeros((batch_size, sequence_size), dtype=torch.int32) torch_attention_mask = torch.zeros(1, sequence_size) if use_optimized_version else None + torch_parameters = preprocess_model_parameters( + f"torch-{model_name}", + initialize_model=lambda: transformers.BertForQuestionAnswering.from_pretrained( + model_name, torchscript=False + ).eval(), + convert_to_ttnn=lambda *_: False, + ) + torch_output = torch_bert_for_question_answering( torch_bert_input, torch_token_type_ids, torch_attention_mask, - parameters=torch_model.state_dict(), - num_encoders=num_encoders, + parameters=torch_parameters, head_size=head_size, ) # Run TT Model - parameters_config = ParametersConfig( - linear_weight_dtype=ttnn.bfloat16, - linear_bias_dtype=ttnn.bfloat16, - layernorm_parameter_dtype=ttnn.bfloat16, - ) parameters = preprocess_model_parameters( - f"{model_name}-{use_optimized_version}", - "version_0", - parameters_config, - initialize_model=lambda: torch_model, - is_to_be_converted=is_to_be_converted, + f"ttnn-{model_name}-{use_optimized_version}", + initialize_model=lambda: transformers.BertForQuestionAnswering.from_pretrained( + model_name, torchscript=False + ).eval(), + convert_to_ttnn=convert_to_ttnn, custom_preprocessor=custom_preprocessor if use_optimized_version else None, device=device, ) @@ -146,7 +143,6 @@ def run_bert_question_and_answering_inference(model_name, batch_size, sequence_s tt_output = bert_for_question_answering( *ttnn_bert_inputs, parameters=parameters, - num_encoders=num_encoders, head_size=head_size, ) tt_output = ttnn.to_layout(tt_output, ttnn.ROW_MAJOR_LAYOUT) diff --git a/tests/ttnn/integration_tests/bloom/test_bloom_for_causal_lm.py b/tests/ttnn/integration_tests/bloom/test_bloom_for_causal_lm.py 
index 446737d94fb..329f7eaa3c3 100644 --- a/tests/ttnn/integration_tests/bloom/test_bloom_for_causal_lm.py +++ b/tests/ttnn/integration_tests/bloom/test_bloom_for_causal_lm.py @@ -13,13 +13,10 @@ from models.utility_functions import skip_for_wormhole_b0 import ttnn -from ttnn.model_preprocessing import ( - preprocess_model_parameters, - ParametersConfig, -) +from ttnn.model_preprocessing import preprocess_model_parameters -def generate_next_token(model, input_ids, parameters, num_heads, hidden_layers, logits_processor, max_length, **kwargs): +def generate_next_token(model, input_ids, parameters, num_heads, logits_processor, max_length, **kwargs): num_tokens = input_ids.shape[-1] padded_input_ids, alibi, causal_mask = model.preprocess_inputs( input_ids=input_ids, @@ -29,7 +26,7 @@ def generate_next_token(model, input_ids, parameters, num_heads, hidden_layers, **kwargs, ) - logits = model.bloom_for_causal_lm(padded_input_ids, alibi, causal_mask, parameters, num_heads, hidden_layers) + logits = model.bloom_for_causal_lm(padded_input_ids, alibi, causal_mask, parameters, num_heads) next_token_logits = logits[:, num_tokens - 1, :] # Get the logits for the last token processed_logits = logits_processor(input_ids, next_token_logits) next_token = torch.argmax(processed_logits, dim=-1).unsqueeze(-1) @@ -43,7 +40,6 @@ def generate_text( tokenizer, logits_processor, num_heads, - hidden_layers, num_tokens_to_decode, max_length=384, **kwargs, @@ -56,7 +52,6 @@ def generate_text( input_ids, parameters, num_heads, - hidden_layers, logits_processor, max_length, **kwargs, @@ -77,16 +72,19 @@ def test_torch_bloom_for_causal_lm(): model_name = "bigscience/bloom-560m" config = BloomConfig.from_pretrained(model_name) tokenizer = BloomTokenizerFast.from_pretrained(model_name) - model = BloomForCausalLM.from_pretrained(model_name).eval() input_text = "Hello, my dog is cute" expected_generated_text = "Hello, my dog is cute. He is a little shy, but he loves" # Initialize logits processor based on the model's configuration num_heads = config.n_head - hidden_layers = config.n_layer - parameters = torch_functional_bloom.preprocess_parameters(model.state_dict(), num_heads) + parameters = preprocess_model_parameters( + f"torch-functional-bloom-for-causal-lm", + initialize_model=lambda: BloomForCausalLM.from_pretrained(model_name).eval(), + custom_preprocessor=torch_functional_bloom.custom_preprocessor, + convert_to_ttnn=lambda *_: False, + ) input_ids = tokenizer.encode(input_text, return_tensors="pt") logits_processor = generation_utils.get_logits_processor(input_ids, config) @@ -98,7 +96,6 @@ def test_torch_bloom_for_causal_lm(): tokenizer, logits_processor, num_heads, - hidden_layers, num_tokens_to_decode=10, ) assert expected_generated_text == generated_text @@ -109,28 +106,19 @@ def test_ttnn_bloom_for_causal_lm(device, batch_size=8): model_name = "bigscience/bloom-560m" config = BloomConfig.from_pretrained(model_name) tokenizer = BloomTokenizerFast.from_pretrained(model_name) - model = BloomForCausalLM.from_pretrained(model_name).eval() input_text = "Hello, my dog is cute" expected_generated_text = "Hello, my dog is cute and sweet. 
He loves to play with me and" num_heads = config.n_head - hidden_layers = config.n_layer - parameters_config = ParametersConfig( - linear_weight_dtype=ttnn.bfloat16, - linear_bias_dtype=ttnn.bfloat16, - layernorm_parameter_dtype=ttnn.bfloat16, - ) parameters = preprocess_model_parameters( f"ttnn-functional-bloom-for-causal-lm", - "version_0", - parameters_config, - initialize_model=lambda: model, + initialize_model=lambda: BloomForCausalLM.from_pretrained(model_name).eval(), device=device, custom_preprocessor=ttnn_optimized_functional_bloom.custom_preprocessor, + convert_to_ttnn=lambda model, name: name != "lm_head", ) - parameters[f"lm_head.weight"] = model.state_dict()[f"lm_head.weight"].T.to(torch.float32) # Initialize logits processor based on the model's configuration input_ids = tokenizer.encode(input_text, return_tensors="pt") @@ -144,7 +132,6 @@ def test_ttnn_bloom_for_causal_lm(device, batch_size=8): tokenizer, logits_processor, num_heads, - hidden_layers, num_tokens_to_decode=10, device=device, ) diff --git a/tests/ttnn/integration_tests/bloom/test_bloom_for_question_answering.py b/tests/ttnn/integration_tests/bloom/test_bloom_for_question_answering.py index eaaf7d38aee..0a145b6af43 100644 --- a/tests/ttnn/integration_tests/bloom/test_bloom_for_question_answering.py +++ b/tests/ttnn/integration_tests/bloom/test_bloom_for_question_answering.py @@ -15,10 +15,7 @@ from models.utility_functions import skip_for_wormhole_b0 import ttnn -from ttnn.model_preprocessing import ( - preprocess_model_parameters, - ParametersConfig, -) +from ttnn.model_preprocessing import preprocess_model_parameters from tests.ttnn.utils_for_testing import assert_with_pcc @@ -34,7 +31,6 @@ def test_pcc_of_bloom_for_question_answering(device, use_program_cache, ttnn_mod torch_model = BloomForQuestionAnswering.from_pretrained(model_name).eval() num_heads = config.n_head - hidden_layers = config.n_layer question = "What is my name?" context = "My name is John." @@ -44,15 +40,8 @@ def test_pcc_of_bloom_for_question_answering(device, use_program_cache, ttnn_mod torch_start_logits = torch_output.start_logits torch_end_logits = torch_output.end_logits - parameters_config = ParametersConfig( - linear_weight_dtype=ttnn.bfloat16, - linear_bias_dtype=ttnn.bfloat16, - layernorm_parameter_dtype=ttnn.bfloat16, - ) parameters = preprocess_model_parameters( f"ttnn-functional-bloom-for-question-answering", - "version_0", - parameters_config, initialize_model=lambda: torch_model, device=device, custom_preprocessor=ttnn_model.custom_preprocessor, @@ -70,9 +59,7 @@ def test_pcc_of_bloom_for_question_answering(device, use_program_cache, ttnn_mod ) # Run twice to measure the time with and without the program cache - tt_output = ttnn_model.bloom_for_question_answering( - input_ids, alibi, causal_mask, parameters, num_heads, hidden_layers - ) + tt_output = ttnn_model.bloom_for_question_answering(input_ids, alibi, causal_mask, parameters, num_heads) tt_output = ttnn.from_device(tt_output) tt_output = ttnn.to_layout(tt_output, ttnn.ROW_MAJOR_LAYOUT) @@ -107,21 +94,13 @@ def test_performance_of_bloom_for_question_answering( tokenizer = BloomTokenizerFast.from_pretrained(model_name) num_heads = config.n_head - hidden_layers = config.n_layer question = "What is my name?" context = "My name is John." 
inputs = tokenizer.encode_plus(question, context, return_tensors="pt") - parameters_config = ParametersConfig( - linear_weight_dtype=ttnn.bfloat16, - linear_bias_dtype=ttnn.bfloat16, - layernorm_parameter_dtype=ttnn.bfloat16, - ) parameters = preprocess_model_parameters( "ttnn-functional-bloom-for-question-answering", - "version_0", - parameters_config, initialize_model=lambda: BloomForQuestionAnswering.from_pretrained(model_name).eval(), device=device, custom_preprocessor=ttnn_model.custom_preprocessor, @@ -141,9 +120,7 @@ def test_performance_of_bloom_for_question_answering( # Run twice to measure the time with and without the program cache for _ in range(2): start = time.time() - tt_output = ttnn_model.bloom_for_question_answering( - input_ids, alibi, causal_mask, parameters, num_heads, hidden_layers - ) + tt_output = ttnn_model.bloom_for_question_answering(input_ids, alibi, causal_mask, parameters, num_heads) tt_output = ttnn.from_device(tt_output) end = time.time() diff --git a/ttnn/model_preprocessing.py b/ttnn/model_preprocessing.py index 92a7b4b6269..fb43de592e3 100644 --- a/ttnn/model_preprocessing.py +++ b/ttnn/model_preprocessing.py @@ -2,10 +2,11 @@ # SPDX-License-Identifier: Apache-2.0 -from dataclasses import dataclass +import copy +import io import pathlib import shutil -from typing import Optional +from typing import Optional, Union from loguru import logger import numpy as np @@ -17,14 +18,6 @@ TILE_WIDTH = 32 -@dataclass -class ParametersConfig: - linear_weight_dtype: ttnn.DataType = ttnn.bfloat16 - linear_bias_dtype: ttnn.DataType = ttnn.bfloat16 - layernorm_parameter_dtype: ttnn.DataType = ttnn.bfloat16 - embedding_weight_dtype: ttnn.DataType = ttnn.bfloat16 - - def pad_tensor(tensor, height_multiple=TILE_HEIGHT, width_multiple=TILE_WIDTH): if len(tensor.shape) > 1: *_, height, width = tensor.shape @@ -80,188 +73,273 @@ def preprocess_embedding_weight(weight, *, dtype): return weight -def default_preprocessor(parameters_config: ParametersConfig, torch_model, full_name): +class ParameterList(list): + def __repr__(self): + file = io.StringIO() + repr_parameters(file, self) + return file.getvalue() + + +class ParameterDict(dict): + __getattr__ = dict.__getitem__ + __delattr__ = dict.__delitem__ + + def __repr__(self): + file = io.StringIO() + repr_parameters(file, self) + return file.getvalue() + + +def make_parameter_dict(dictionary: Union[dict, ParameterDict]) -> ParameterDict: + if isinstance(dictionary, ParameterDict): + return dictionary + preprocessed_dictionary = {} + for key, value in dictionary.items(): + if isinstance(value, dict): + value = make_parameter_dict(value) + preprocessed_dictionary[key] = value + return ParameterDict(preprocessed_dictionary) + + +def repr_parameters(file, parameters, indentation=""): + next_indentation = indentation + " " + if isinstance(parameters, ParameterDict): + if not parameters: + file.write("{}") + return + + file.write("{\n") + for index, (key, value) in enumerate(parameters.items()): + file.write(next_indentation) + file.write(f"{key}: ") + repr_parameters(file, value, next_indentation) + file.write(",\n" if index < len(parameters) - 1 else "\n") + file.write(indentation) + file.write("}") + elif isinstance(parameters, ParameterList): + if not parameters: + file.write("[]") + return + + file.write("[\n") + for index, element in enumerate(parameters): + file.write(next_indentation) + repr_parameters(file, element, next_indentation) + file.write(",\n" if index < len(parameters) - 1 else "\n") + file.write(indentation) + 
file.write("]") + else: + file.write(repr(parameters.shape)) + + +def default_preprocessor(model, name) -> ParameterDict: parameters = {} - if isinstance(torch_model, torch.nn.Linear): - parameters[f"{full_name}weight"] = preprocess_linear_weight( - torch_model.weight, dtype=parameters_config.linear_weight_dtype - ) - if torch_model.bias is not None: - parameters[f"{full_name}bias"] = preprocess_linear_bias( - torch_model.bias, dtype=parameters_config.linear_bias_dtype - ) - elif isinstance(torch_model, torch.nn.LayerNorm): - parameters[f"{full_name}weight"] = preprocess_layernorm_parameter( - torch_model.weight, dtype=parameters_config.layernorm_parameter_dtype - ) - parameters[f"{full_name}bias"] = preprocess_layernorm_parameter( - torch_model.bias, dtype=parameters_config.layernorm_parameter_dtype - ) - elif isinstance(torch_model, torch.nn.Embedding): - parameters[f"{full_name}weight"] = preprocess_embedding_weight( - torch_model.weight, dtype=parameters_config.embedding_weight_dtype - ) - return parameters + if isinstance(model, torch.nn.Linear): + parameters[f"weight"] = preprocess_linear_weight(model.weight, dtype=ttnn.bfloat16) + if model.bias is not None: + parameters[f"bias"] = preprocess_linear_bias(model.bias, dtype=ttnn.bfloat16) + elif isinstance(model, torch.nn.LayerNorm): + parameters[f"weight"] = preprocess_layernorm_parameter(model.weight, dtype=ttnn.bfloat16) + parameters[f"bias"] = preprocess_layernorm_parameter(model.bias, dtype=ttnn.bfloat16) + elif isinstance(model, torch.nn.Embedding): + parameters[f"weight"] = preprocess_embedding_weight(model.weight, dtype=ttnn.bfloat16) + return make_parameter_dict(parameters) def _preprocess_model_parameters( - parameters_config, - torch_model, + model, *, - prefix="", - is_to_be_converted, + convert_to_ttnn, custom_preprocessor=None, -): - parameters = {} - - named_children = list(torch_model.named_children()) - - if not named_children: - for name, parameter in torch_model.named_parameters(): - full_name = f"{prefix}{name}" - parameters[full_name] = parameter - - for name, child in named_children: - full_name = f"{prefix}{name}." 
- - use_default_preprocessor = True - if custom_preprocessor is not None: - custom_preprocessor_parameters = custom_preprocessor( - parameters_config=parameters_config, - torch_model=child, - full_name=full_name, - ) - if custom_preprocessor_parameters: - parameters.update(custom_preprocessor_parameters) - # Custom preprocessor didn't handle this case, so, try using default preprocessor - use_default_preprocessor = False - - if use_default_preprocessor: - if not is_to_be_converted(child, full_name): - child_parameters = _preprocess_model_parameters( - parameters_config, + name="", +) -> ParameterDict: + if isinstance(model, torch.nn.modules.container.ModuleList): + return ParameterList( + [ + _preprocess_model_parameters( child, - prefix=full_name, - is_to_be_converted=is_to_be_converted, + convert_to_ttnn=convert_to_ttnn, custom_preprocessor=custom_preprocessor, + name=f"{name}.{index}" if name else f"{index}", ) - parameters.update(child_parameters) - else: - default_preprocessor_parameters = default_preprocessor(parameters_config, child, full_name) - if default_preprocessor_parameters: - parameters.update(default_preprocessor_parameters) - else: - child_parameters = _preprocess_model_parameters( - parameters_config, - child, - prefix=full_name, - is_to_be_converted=is_to_be_converted, - custom_preprocessor=custom_preprocessor, - ) - parameters.update(child_parameters) + for index, child in enumerate(model.children()) + ] + ) - return parameters + if custom_preprocessor is not None: + custom_preprocessor_parameters = custom_preprocessor(model, name) + if custom_preprocessor_parameters: + return make_parameter_dict(custom_preprocessor_parameters) + + if convert_to_ttnn(model, name): + default_preprocessor_parameters = default_preprocessor(model, name) + if default_preprocessor_parameters: + return make_parameter_dict(default_preprocessor_parameters) + if isinstance(model, torch.nn.Linear): + # TODO: remove deepcopy. 
It's needed because we don't want to modify the actual model + model = copy.deepcopy(model) + model.weight = torch.nn.Parameter(model.weight.T.contiguous()) + elif isinstance(model, torch.nn.Conv2d): + raise RuntimeError("Transpose conv weights?") + + named_children = list(model.named_children()) + if not named_children: + return make_parameter_dict(dict(model.named_parameters())) -def _load_parameters(model_cache_path: pathlib.Path) -> dict: parameters = {} - for file_name in model_cache_path.glob("*"): - if file_name.name == "version.txt": + for child_name, child in named_children: + parameters[child_name] = _preprocess_model_parameters( + child, + convert_to_ttnn=convert_to_ttnn, + custom_preprocessor=custom_preprocessor, + name=f"{name}.{child_name}" if name else child_name, + ) + + parameters = make_parameter_dict(parameters) + + return parameters + + +def _load_parameters(model_cache_path: pathlib.Path) -> ParameterDict: + output = {} + for path in model_cache_path.glob("*"): + if path.name == "version.txt": continue - extension = file_name.suffix - name = file_name.stem - if extension == ".bin": - parameters[name] = ttnn.load_tensor(file_name) + extension = path.suffix + name = path.stem + + if path.is_dir(): + parameters = _load_parameters(path) + if all(str(key).isdigit() for key in parameters): + parameters = {int(key): value for key, value in parameters.items()} + parameters = ParameterList([parameters[index] for index in sorted(parameters.keys())]) + output[name] = parameters + elif extension == ".bin": + output[name] = ttnn.load_tensor(path) elif extension == ".pt": - parameters[name] = torch.load(file_name) + output[name] = torch.load(path) else: raise RuntimeError("Unrecognized file extension!") - return parameters + return ParameterDict(output) -def _dump_parameters(model_cache_path: pathlib.Path, parameters: dict) -> None: +def _dump_parameters(model_cache_path: pathlib.Path, parameters: ParameterDict) -> None: model_cache_path.mkdir(parents=True) - for name, tensor in parameters.items(): - file_path = str(model_cache_path / name) - if isinstance(tensor, ttnn.Tensor): + for name, value in parameters.items(): + if isinstance(value, ParameterDict): + _dump_parameters(model_cache_path / name, value) + elif isinstance(value, ParameterList): + for index, element in enumerate(value): + _dump_parameters(model_cache_path / name / str(index), element) + elif isinstance(value, ttnn.Tensor): + file_path = str(model_cache_path / name) file_name = file_path + ".bin" - ttnn.dump_tensor(file_name, tensor) - elif isinstance(tensor, torch.nn.Parameter): + ttnn.dump_tensor(file_name, value) + elif isinstance(value, (torch.Tensor, torch.nn.Parameter)): + file_path = str(model_cache_path / name) file_name = file_path + ".pt" - torch.save(tensor, file_name) + torch.save(value, file_name) else: - raise RuntimeError(f"Unsupported type: {type(tensor)}") + raise RuntimeError(f"Unsupported type: {type(value)}") + + +def move_to_device(parameters, device): + for name, value in list(parameters.items()): + if isinstance(value, ParameterDict): + parameters[name] = move_to_device(value, device) + elif isinstance(value, ParameterList): + for index, element in enumerate(value): + parameters[name][index] = move_to_device(element, device) + elif isinstance(value, ttnn.Tensor): + parameters[name] = ttnn.to_device(value, device) + else: + parameters[name] = value + return parameters + + +def git_hash(): + try: + import subprocess + + return subprocess.check_output(["git", "rev-parse", "--short", 
"HEAD"]).decode("ascii").strip() + except Exception as e: + raise RuntimeError("Couldn't get git hash!") from e def preprocess_model_parameters( - model_name, - version, - parameters_config, + model_name=None, + version=None, *, initialize_model, - prefix="", - is_to_be_converted=None, + convert_to_ttnn=None, custom_preprocessor=None, device: Optional[ttnn.Device] = None, -): - model_cache_path = ttnn.MODEL_CACHE_PATH / model_name - version_file_path = model_cache_path / "version.txt" - - cache_exists = model_cache_path.exists() - if cache_exists: - if version_file_path.exists(): - with open(version_file_path) as f: - cached_version = f.readline() - else: - cached_version = None +) -> ParameterDict: + if convert_to_ttnn is None: - version_matches = version == cached_version - else: - version_matches = False + def convert_to_ttnn(model, full_name): + return True + + if model_name is None: + model = initialize_model() + parameters = _preprocess_model_parameters( + model, + convert_to_ttnn=convert_to_ttnn, + custom_preprocessor=custom_preprocessor, + ) - if cache_exists and version_matches: - logger.info(f'Loading model weights from cache: {model_cache_path} (version "{version}")') - parameters = _load_parameters(model_cache_path) - logger.info(f'Loaded model weights from cache: {model_cache_path} (version "{version}")') else: - if initialize_model is None: - raise RuntimeError(f'Cached weights for the model {model_name} (version "{version}") don\'t exist') + model_cache_path = ttnn.MODEL_CACHE_PATH / model_name + version_file_path = model_cache_path / "version.txt" - logger.info(f'Saving model weights to cache: {model_cache_path} (version "{version}")') + if version is None: + version = git_hash() - if is_to_be_converted is None: + cache_exists = model_cache_path.exists() + if cache_exists: + if version_file_path.exists(): + with open(version_file_path) as f: + cached_version = f.readline() + else: + cached_version = None - def is_to_be_converted(*args, **kwargs): - return True + version_matches = version == cached_version + else: + version_matches = False - torch_model = initialize_model() - parameters = _preprocess_model_parameters( - parameters_config, - torch_model, - prefix=prefix, - is_to_be_converted=is_to_be_converted, - custom_preprocessor=custom_preprocessor, - ) + if cache_exists and version_matches: + logger.info(f'Loading model weights from cache: {model_cache_path} (version "{version}")') + parameters = _load_parameters(model_cache_path) + logger.info(f'Loaded model weights from cache: {model_cache_path} (version "{version}")') + else: + if initialize_model is None: + raise RuntimeError(f'Cached weights for the model {model_name} (version "{version}") don\'t exist') + + logger.info(f'Saving model weights to cache: {model_cache_path} (version "{version}")') - # TODO: use temporary directory - if model_cache_path.exists(): - shutil.rmtree(model_cache_path) + model = initialize_model() + parameters = _preprocess_model_parameters( + model, + convert_to_ttnn=convert_to_ttnn, + custom_preprocessor=custom_preprocessor, + ) + + # TODO: use temporary directory + if model_cache_path.exists(): + shutil.rmtree(model_cache_path) - _dump_parameters(model_cache_path, parameters) + _dump_parameters(model_cache_path, parameters) - with open(version_file_path, "w") as f: - f.write(version) + with open(version_file_path, "w") as f: + f.write(version) - logger.info(f'Saved model weights to cache: {model_cache_path} (version "{version}")') + logger.info(f'Saved model weights to cache: 
{model_cache_path} (version "{version}")') if device is not None: - logger.info(f'Moving model weights to device: {model_cache_path} (version "{version}")') - for name, parameter in list(parameters.items()): - if isinstance(parameter, ttnn.Tensor): - parameters[name] = ttnn.to_device(parameter, device) - else: - parameters[name] = parameter - logger.info(f'Moved model weights to device: {model_cache_path} (version "{version}")') + logger.info(f"Moving model weights to device") + parameters = move_to_device(parameters, device) + logger.info(f"Moved model weights to device") return parameters
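For readers of this patch, the sketch below shows how the reworked `preprocess_model_parameters` API is intended to be used after these changes. It is a minimal example under stated assumptions: `ttnn` and `transformers` are installed, the cache name and the `bert-base-uncased` checkpoint are illustrative, and the fused-QKV custom preprocessor simply mirrors the BERT test updated in this diff rather than being the only valid usage.

```python
# Minimal post-patch usage sketch (assumptions: ttnn + transformers installed;
# cache key and checkpoint are illustrative, mirroring tests/ttnn/integration_tests/bert).
import torch
import transformers
import ttnn
from ttnn.model_preprocessing import (
    preprocess_model_parameters,
    preprocess_linear_weight,
    preprocess_linear_bias,
)


def custom_preprocessor(torch_model, name):
    # New signature: (module, name) -> nested dict. Returning an empty dict falls
    # through to the default Linear/LayerNorm/Embedding handling.
    parameters = {}
    if isinstance(torch_model, transformers.models.bert.modeling_bert.BertSelfAttention):
        qkv_weight = torch.cat(
            [torch_model.query.weight, torch_model.key.weight, torch_model.value.weight], dim=0
        )
        qkv_bias = torch.cat(
            [torch_model.query.bias, torch_model.key.bias, torch_model.value.bias], dim=0
        )
        parameters = {
            "fused_qkv": {
                "weight": preprocess_linear_weight(qkv_weight, dtype=ttnn.bfloat16),
                "bias": preprocess_linear_bias(qkv_bias, dtype=ttnn.bfloat16),
            }
        }
    return parameters


parameters = preprocess_model_parameters(
    "ttnn-bert-base-uncased",  # cache key; version now defaults to the current git hash
    initialize_model=lambda: transformers.BertForQuestionAnswering.from_pretrained(
        "bert-base-uncased"
    ).eval(),
    convert_to_ttnn=lambda model, name: True,  # use `lambda *_: False` to keep torch tensors
    custom_preprocessor=custom_preprocessor,
    # device=device,  # optionally move the converted tensors onto a TT device
)

# The returned parameters mirror the torch module hierarchy: ParameterDict supports
# attribute access and torch.nn.ModuleList children become indexable ParameterLists,
# so per-layer indices no longer need to be spelled out in string keys.
first_encoder = parameters.bert.encoder.layer[0]
print(first_encoder.attention.self.fused_qkv.weight.shape)
print(first_encoder.output.LayerNorm.bias.shape)
```

As the updated tests in this diff show, the same call with `convert_to_ttnn=lambda *_: False` keeps the original torch tensors but still returns them in the structured `ParameterDict` form, which is how the torch-functional reference models are now fed.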