#4003: updated ttnn.model_preprocessing to keep the structure of the model weights
arakhmati committed Dec 7, 2023
1 parent 657d51c commit 88dba1a
Showing 10 changed files with 451 additions and 447 deletions.
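Before this change, the preprocessed `parameters` object was a flat dictionary keyed by fully qualified module names, so every call site had to rebuild strings such as `bert.encoder.layer.{encoder_index}.attention.self.query.weight`. After the change, the parameters mirror the module hierarchy of the original model, and each encoder layer receives only its own sub-tree of weights. The sketch below illustrates the idea only; `ParameterDict` and `nest` are hypothetical stand-ins, not the actual ttnn.model_preprocessing implementation, and in the real code `parameters.bert.encoder.layer` is additionally iterable per encoder layer, as the loops in the diff show.

# Minimal sketch of the access-pattern change; ParameterDict and nest are hypothetical
# illustrations, not the actual ttnn.model_preprocessing implementation.
import torch


class ParameterDict(dict):
    """A dict whose string keys can also be read as attributes."""

    def __getattr__(self, name):
        try:
            return self[name]
        except KeyError as error:
            raise AttributeError(name) from error


def nest(flat_parameters):
    """Rebuild the module hierarchy from flat 'a.b.c' keys."""
    root = ParameterDict()
    for full_name, tensor in flat_parameters.items():
        *path, leaf = full_name.split(".")
        node = root
        for part in path:
            node = node.setdefault(part, ParameterDict())
        node[leaf] = tensor
    return root


flat = {
    "bert.encoder.layer.0.attention.self.query.weight": torch.randn(768, 768),
    "bert.encoder.layer.0.attention.self.query.bias": torch.randn(768),
}

# Old style: every lookup rebuilds the full string key for a given encoder index.
encoder_index = 0
query_weight = flat[f"bert.encoder.layer.{encoder_index}.attention.self.query.weight"]

# New style: the structure is preserved, so one encoder layer's parameters can be
# passed around as a single nested object and read with attribute access.
parameters = nest(flat)
encoder_parameters = parameters.bert.encoder.layer["0"]
assert torch.equal(encoder_parameters.attention.self.query.weight, query_weight)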
@@ -70,44 +70,43 @@ def torch_bert_encoder(
attention_mask,
parameters,
*,
-    encoder_index,
head_size,
):
*_, hidden_size = hidden_states.shape
multi_head_attention_output = torch_multi_head_attention(
hidden_states,
attention_mask,
parameters[f"bert.encoder.layer.{encoder_index}.attention.self.query.weight"].T,
parameters[f"bert.encoder.layer.{encoder_index}.attention.self.query.bias"],
parameters[f"bert.encoder.layer.{encoder_index}.attention.self.key.weight"].T,
parameters[f"bert.encoder.layer.{encoder_index}.attention.self.key.bias"],
parameters[f"bert.encoder.layer.{encoder_index}.attention.self.value.weight"].T,
parameters[f"bert.encoder.layer.{encoder_index}.attention.self.value.bias"],
parameters[f"bert.encoder.layer.{encoder_index}.attention.output.dense.weight"].T,
parameters[f"bert.encoder.layer.{encoder_index}.attention.output.dense.bias"],
parameters.attention.self.query.weight,
parameters.attention.self.query.bias,
parameters.attention.self.key.weight,
parameters.attention.self.key.bias,
parameters.attention.self.value.weight,
parameters.attention.self.value.bias,
parameters.attention.output.dense.weight,
parameters.attention.output.dense.bias,
head_size=head_size,
)

multi_head_attention_add_and_layer_norm_output = F.layer_norm(
hidden_states + multi_head_attention_output,
(hidden_size,),
parameters[f"bert.encoder.layer.{encoder_index}.attention.output.LayerNorm.weight"],
parameters[f"bert.encoder.layer.{encoder_index}.attention.output.LayerNorm.bias"],
parameters.attention.output.LayerNorm.weight,
parameters.attention.output.LayerNorm.bias,
)

feedforward_output = torch_feedforward(
multi_head_attention_add_and_layer_norm_output,
parameters[f"bert.encoder.layer.{encoder_index}.intermediate.dense.weight"].T,
parameters[f"bert.encoder.layer.{encoder_index}.intermediate.dense.bias"],
parameters[f"bert.encoder.layer.{encoder_index}.output.dense.weight"].T,
parameters[f"bert.encoder.layer.{encoder_index}.output.dense.bias"],
parameters.intermediate.dense.weight,
parameters.intermediate.dense.bias,
parameters.output.dense.weight,
parameters.output.dense.bias,
)

feedforward_add_and_layer_norm_output = F.layer_norm(
multi_head_attention_add_and_layer_norm_output + feedforward_output,
(hidden_size,),
parameters[f"bert.encoder.layer.{encoder_index}.output.LayerNorm.weight"],
parameters[f"bert.encoder.layer.{encoder_index}.output.LayerNorm.bias"],
parameters.output.LayerNorm.weight,
parameters.output.LayerNorm.bias,
)

return feedforward_add_and_layer_norm_output
@@ -119,28 +118,26 @@ def torch_bert(
attention_mask,
parameters,
*,
-    num_encoders,
head_size,
):
-    word_embeddings = F.embedding(input_ids, parameters["bert.embeddings.word_embeddings.weight"])
-    token_type_embeddings = F.embedding(token_type_ids, parameters["bert.embeddings.token_type_embeddings.weight"])
+    word_embeddings = F.embedding(input_ids, parameters.bert.embeddings.word_embeddings.weight)
+    token_type_embeddings = F.embedding(token_type_ids, parameters.bert.embeddings.token_type_embeddings.weight)
encoder_input = word_embeddings + token_type_embeddings

*_, hidden_size = encoder_input.shape
encoder_input = F.layer_norm(
encoder_input,
(hidden_size,),
parameters["bert.embeddings.LayerNorm.weight"],
parameters["bert.embeddings.LayerNorm.bias"],
parameters.bert.embeddings.LayerNorm.weight,
parameters.bert.embeddings.LayerNorm.bias,
)

encoder_output = None
-    for encoder_index in range(num_encoders):
+    for encoder_parameters in parameters.bert.encoder.layer:
encoder_output = torch_bert_encoder(
encoder_input,
attention_mask,
-            parameters,
-            encoder_index=encoder_index,
+            encoder_parameters,
head_size=head_size,
)
encoder_input = encoder_output
@@ -153,19 +150,17 @@ def torch_bert_for_question_answering(
attention_mask,
parameters,
*,
-    num_encoders,
head_size,
):
bert_output = torch_bert(
input_ids,
token_type_ids,
attention_mask,
parameters,
-        num_encoders=num_encoders,
head_size=head_size,
)

qa_outputs = bert_output
qa_outputs = qa_outputs @ parameters["qa_outputs.weight"].T
qa_outputs = qa_outputs + parameters["qa_outputs.bias"]
qa_outputs = qa_outputs @ parameters.qa_outputs.weight
qa_outputs = qa_outputs + parameters.qa_outputs.bias
return qa_outputs
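One consequence of the new structure is that `num_encoders` and `encoder_index` disappear from the reference API, since the number of layers is implied by `parameters.bert.encoder.layer`; the dropped `.T` on the linear weights also suggests the preprocessing now stores them pre-transposed for these functional implementations. Below is a rough sketch of how the updated reference model might be called; the tensor shapes, the additive mask layout, and the pre-built `parameters` object are illustrative assumptions, not taken from the repository's tests.

# Hypothetical call site for the updated reference model. `parameters` is assumed to be
# the nested structure produced by ttnn.model_preprocessing for a BERT-base checkpoint;
# shapes and the additive attention-mask layout are illustrative assumptions.
import torch

batch_size, sequence_size = 8, 384
hidden_size, num_heads, vocab_size = 768, 12, 30522

input_ids = torch.randint(0, vocab_size, (batch_size, sequence_size))
token_type_ids = torch.zeros((batch_size, sequence_size), dtype=torch.int64)
attention_mask = torch.zeros((batch_size, 1, 1, sequence_size))  # additive mask, 0 = keep

qa_outputs = torch_bert_for_question_answering(
    input_ids,
    token_type_ids,
    attention_mask,
    parameters,  # nested per-module weights; no num_encoders argument anymore
    head_size=hidden_size // num_heads,
)
# qa_outputs carries one start and one end logit per token for extractive QA.
start_logits, end_logits = qa_outputs[..., 0], qa_outputs[..., 1]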
53 changes: 24 additions & 29 deletions models/experimental/functional_bert/tt/ttnn_functional_bert.py
@@ -75,41 +75,40 @@ def ttnn_bert_encoder(
attention_mask,
parameters,
*,
-    encoder_index,
head_size,
):
multi_head_attention_output = ttnn_multi_head_attention(
hidden_states,
attention_mask,
parameters[f"bert.encoder.layer.{encoder_index}.attention.self.query.weight"],
parameters[f"bert.encoder.layer.{encoder_index}.attention.self.query.bias"],
parameters[f"bert.encoder.layer.{encoder_index}.attention.self.key.weight"],
parameters[f"bert.encoder.layer.{encoder_index}.attention.self.key.bias"],
parameters[f"bert.encoder.layer.{encoder_index}.attention.self.value.weight"],
parameters[f"bert.encoder.layer.{encoder_index}.attention.self.value.bias"],
parameters[f"bert.encoder.layer.{encoder_index}.attention.output.dense.weight"],
parameters[f"bert.encoder.layer.{encoder_index}.attention.output.dense.bias"],
parameters.attention.self.query.weight,
parameters.attention.self.query.bias,
parameters.attention.self.key.weight,
parameters.attention.self.key.bias,
parameters.attention.self.value.weight,
parameters.attention.self.value.bias,
parameters.attention.output.dense.weight,
parameters.attention.output.dense.bias,
head_size=head_size,
)

hidden_states = ttnn.experimental.layer_norm(
hidden_states + multi_head_attention_output,
weight=parameters[f"bert.encoder.layer.{encoder_index}.attention.output.LayerNorm.weight"],
bias=parameters[f"bert.encoder.layer.{encoder_index}.attention.output.LayerNorm.bias"],
weight=parameters.attention.output.LayerNorm.weight,
bias=parameters.attention.output.LayerNorm.bias,
)

feedforward_output = ttnn_feedforward(
hidden_states,
parameters[f"bert.encoder.layer.{encoder_index}.intermediate.dense.weight"],
parameters[f"bert.encoder.layer.{encoder_index}.intermediate.dense.bias"],
parameters[f"bert.encoder.layer.{encoder_index}.output.dense.weight"],
parameters[f"bert.encoder.layer.{encoder_index}.output.dense.bias"],
parameters.intermediate.dense.weight,
parameters.intermediate.dense.bias,
parameters.output.dense.weight,
parameters.output.dense.bias,
)

hidden_states = ttnn.experimental.layer_norm(
hidden_states + feedforward_output,
weight=parameters[f"bert.encoder.layer.{encoder_index}.output.LayerNorm.weight"],
bias=parameters[f"bert.encoder.layer.{encoder_index}.output.LayerNorm.bias"],
weight=parameters.output.LayerNorm.weight,
bias=parameters.output.LayerNorm.bias,
)

return hidden_states
@@ -121,30 +120,28 @@ def ttnn_bert(
attention_mask,
parameters,
*,
-    num_encoders,
head_size,
):
word_embeddings = ttnn.embedding(
input_ids, parameters["bert.embeddings.word_embeddings.weight"], layout=ttnn.TILE_LAYOUT
input_ids, parameters.bert.embeddings.word_embeddings.weight, layout=ttnn.TILE_LAYOUT
)
token_type_embeddings = ttnn.embedding(
token_type_ids, parameters["bert.embeddings.token_type_embeddings.weight"], layout=ttnn.TILE_LAYOUT
token_type_ids, parameters.bert.embeddings.token_type_embeddings.weight, layout=ttnn.TILE_LAYOUT
)
encoder_input = word_embeddings + token_type_embeddings

encoder_input = ttnn.experimental.layer_norm(
encoder_input,
weight=parameters[f"bert.embeddings.LayerNorm.weight"],
bias=parameters[f"bert.embeddings.LayerNorm.bias"],
weight=parameters.bert.embeddings.LayerNorm.weight,
bias=parameters.bert.embeddings.LayerNorm.bias,
)

encoder_output = None
-    for encoder_index in range(num_encoders):
+    for encoder_parameters in parameters.bert.encoder.layer:
encoder_output = ttnn_bert_encoder(
encoder_input,
attention_mask,
-            parameters,
-            encoder_index=encoder_index,
+            encoder_parameters,
head_size=head_size,
)
encoder_input = encoder_output
@@ -157,20 +154,18 @@ def ttnn_bert_for_question_answering(
attention_mask,
parameters,
*,
-    num_encoders,
head_size,
):
bert_output = ttnn_bert(
input_ids,
token_type_ids,
attention_mask,
parameters,
-        num_encoders=num_encoders,
head_size=head_size,
)

qa_outputs = bert_output
qa_outputs = qa_outputs @ parameters["qa_outputs.weight"]
qa_outputs = qa_outputs + parameters["qa_outputs.bias"]
qa_outputs = qa_outputs @ parameters.qa_outputs.weight
qa_outputs = qa_outputs + parameters.qa_outputs.bias

return qa_outputs
@@ -110,42 +110,41 @@ def ttnn_optimized_bert_encoder(
attention_mask,
parameters,
*,
-    encoder_index,
head_size,
):
multi_head_attention_output = ttnn_optimized_multi_head_attention(
hidden_states,
attention_mask,
parameters[f"bert.encoder.layer.{encoder_index}.attention.self.fused_qkv.weight"],
parameters[f"bert.encoder.layer.{encoder_index}.attention.self.fused_qkv.bias"],
parameters[f"bert.encoder.layer.{encoder_index}.attention.output.dense.weight"],
parameters[f"bert.encoder.layer.{encoder_index}.attention.output.dense.bias"],
parameters.attention.self.fused_qkv.weight,
parameters.attention.self.fused_qkv.bias,
parameters.attention.output.dense.weight,
parameters.attention.output.dense.bias,
head_size=head_size,
)

multi_head_attention_add_and_layer_norm_output = ttnn.experimental.layer_norm(
hidden_states,
residual_input=multi_head_attention_output,
weight=parameters[f"bert.encoder.layer.{encoder_index}.attention.output.LayerNorm.weight"],
bias=parameters[f"bert.encoder.layer.{encoder_index}.attention.output.LayerNorm.bias"],
weight=parameters.attention.output.LayerNorm.weight,
bias=parameters.attention.output.LayerNorm.bias,
memory_config=ttnn.L1_MEMORY_CONFIG,
)
ttnn.deallocate(hidden_states)
ttnn.deallocate(multi_head_attention_output)

feedforward_output = ttnn_optimized_feedforward(
multi_head_attention_add_and_layer_norm_output,
parameters[f"bert.encoder.layer.{encoder_index}.intermediate.dense.weight"],
parameters[f"bert.encoder.layer.{encoder_index}.intermediate.dense.bias"],
parameters[f"bert.encoder.layer.{encoder_index}.output.dense.weight"],
parameters[f"bert.encoder.layer.{encoder_index}.output.dense.bias"],
parameters.intermediate.dense.weight,
parameters.intermediate.dense.bias,
parameters.output.dense.weight,
parameters.output.dense.bias,
)

feedforward_add_and_layer_norm_output = ttnn.experimental.layer_norm(
multi_head_attention_add_and_layer_norm_output,
residual_input=feedforward_output,
weight=parameters[f"bert.encoder.layer.{encoder_index}.output.LayerNorm.weight"],
bias=parameters[f"bert.encoder.layer.{encoder_index}.output.LayerNorm.bias"],
weight=parameters.output.LayerNorm.weight,
bias=parameters.output.LayerNorm.bias,
memory_config=ttnn.L1_MEMORY_CONFIG,
)
ttnn.deallocate(multi_head_attention_add_and_layer_norm_output)
@@ -160,21 +159,20 @@ def ttnn_optimized_bert(
attention_mask,
parameters,
*,
-    num_encoders,
head_size,
):
import tt_lib as ttl

word_embeddings = ttnn.embedding(
input_ids,
parameters["bert.embeddings.word_embeddings.weight"],
parameters.bert.embeddings.word_embeddings.weight,
layout=ttnn.TILE_LAYOUT,
)
ttnn.deallocate(input_ids)

token_type_embeddings = ttnn.embedding(
token_type_ids,
parameters["bert.embeddings.token_type_embeddings.weight"],
parameters.bert.embeddings.token_type_embeddings.weight,
layout=ttnn.TILE_LAYOUT,
)
ttnn.deallocate(token_type_ids)
Expand All @@ -185,19 +183,18 @@ def ttnn_optimized_bert(

encoder_input = ttnn.experimental.layer_norm(
embeddings,
weight=parameters[f"bert.embeddings.LayerNorm.weight"],
bias=parameters[f"bert.embeddings.LayerNorm.bias"],
weight=parameters.bert.embeddings.LayerNorm.weight,
bias=parameters.bert.embeddings.LayerNorm.bias,
memory_config=ttnn.L1_MEMORY_CONFIG,
)
ttnn.deallocate(embeddings)

encoder_output = None
-    for encoder_index in range(num_encoders):
+    for encoder_parameters in parameters.bert.encoder.layer:
encoder_output = ttnn_optimized_bert_encoder(
encoder_input,
attention_mask,
-            parameters,
-            encoder_index=encoder_index,
+            encoder_parameters,
head_size=head_size,
)
encoder_output = ttnn.reallocate(encoder_output)
@@ -212,22 +209,20 @@ def ttnn_optimized_bert_for_question_answering(
attention_mask,
parameters,
*,
-    num_encoders,
head_size,
):
bert_output = ttnn_optimized_bert(
input_ids,
token_type_ids,
attention_mask,
parameters,
-        num_encoders=num_encoders,
head_size=head_size,
)

qa_outputs = ttnn.linear(
bert_output,
parameters["qa_outputs.weight"],
bias=parameters["qa_outputs.bias"],
parameters.qa_outputs.weight,
bias=parameters.qa_outputs.bias,
memory_config=ttnn.L1_MEMORY_CONFIG,
)
