#4003: added support for both t5-small and google/flan-t5-small. Added functional optimized t5. Added functional bert to perf models pipeline
arakhmati authored and TT-billteng committed Dec 22, 2023
1 parent d477a42 commit cc05424
Showing 40 changed files with 1,956 additions and 605 deletions.
@@ -18,10 +18,10 @@ def torch_multi_head_attention(
output_weight,
output_bias,
*,
- head_size,
+ num_heads,
):
batch_size, sequence_size, hidden_size = hidden_states.shape
- num_heads = hidden_size // head_size
+ head_size = hidden_size // num_heads

query = hidden_states @ query_weight
query = query + query_bias
@@ -70,7 +70,7 @@ def torch_bert_encoder(
attention_mask,
parameters,
*,
- head_size,
+ num_heads,
):
*_, hidden_size = hidden_states.shape
multi_head_attention_output = torch_multi_head_attention(
@@ -84,7 +84,7 @@ def torch_bert_encoder(
parameters.attention.self.value.bias,
parameters.attention.output.dense.weight,
parameters.attention.output.dense.bias,
- head_size=head_size,
+ num_heads=num_heads,
)

multi_head_attention_add_and_layer_norm_output = F.layer_norm(
@@ -118,7 +118,7 @@ def torch_bert(
attention_mask,
parameters,
*,
- head_size,
+ num_heads,
):
word_embeddings = F.embedding(input_ids, parameters.bert.embeddings.word_embeddings.weight)
token_type_embeddings = F.embedding(token_type_ids, parameters.bert.embeddings.token_type_embeddings.weight)
@@ -138,7 +138,7 @@ def torch_bert(
encoder_input,
attention_mask,
encoder_parameters,
- head_size=head_size,
+ num_heads=num_heads,
)
encoder_input = encoder_output
return encoder_output
@@ -150,14 +150,14 @@ def torch_bert_for_question_answering(
attention_mask,
parameters,
*,
- head_size,
+ num_heads,
):
bert_output = torch_bert(
input_ids,
token_type_ids,
attention_mask,
parameters,
- head_size=head_size,
+ num_heads=num_heads,
)

qa_outputs = bert_output
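Across the functional BERT reference above, the keyword argument changes from head_size to num_heads, and head_size is now derived inside the attention code as hidden_size // num_heads. A minimal PyTorch sketch of that derivation and of the split into heads (hypothetical helper and shapes, not code from this commit):

import torch


def split_into_heads(tensor, *, num_heads):
    # [batch_size, sequence_size, hidden_size] -> [batch_size, num_heads, sequence_size, head_size]
    batch_size, sequence_size, hidden_size = tensor.shape
    head_size = hidden_size // num_heads  # derived here instead of being passed by the caller
    return tensor.view(batch_size, sequence_size, num_heads, head_size).permute(0, 2, 1, 3)


hidden_states = torch.randn(8, 384, 1024)                # hypothetical BERT-large-sized activations
per_head = split_into_heads(hidden_states, num_heads=16)
assert per_head.shape == (8, 16, 384, 64)                # head_size = 1024 // 16 = 64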
16 changes: 8 additions & 8 deletions models/experimental/functional_bert/tt/ttnn_functional_bert.py
@@ -17,10 +17,10 @@ def ttnn_multi_head_attention(
output_weight,
output_bias,
*,
- head_size,
+ num_heads,
):
batch_size, sequence_size, hidden_size = hidden_states.shape
- num_heads = hidden_size // head_size
+ head_size = hidden_size // num_heads

query = hidden_states @ query_weight
query = query + query_bias
@@ -75,7 +75,7 @@ def ttnn_bert_encoder(
attention_mask,
parameters,
*,
- head_size,
+ num_heads,
):
multi_head_attention_output = ttnn_multi_head_attention(
hidden_states,
@@ -88,7 +88,7 @@ def ttnn_bert_encoder(
parameters.attention.self.value.bias,
parameters.attention.output.dense.weight,
parameters.attention.output.dense.bias,
- head_size=head_size,
+ num_heads=num_heads,
)

hidden_states = ttnn.layer_norm(
@@ -120,7 +120,7 @@ def ttnn_bert(
attention_mask,
parameters,
*,
- head_size,
+ num_heads,
):
word_embeddings = ttnn.embedding(
input_ids, parameters.bert.embeddings.word_embeddings.weight, layout=ttnn.TILE_LAYOUT
@@ -142,7 +142,7 @@ def ttnn_bert(
encoder_input,
attention_mask,
encoder_parameters,
- head_size=head_size,
+ num_heads=num_heads,
)
encoder_input = encoder_output
return encoder_output
@@ -154,14 +154,14 @@ def ttnn_bert_for_question_answering(
attention_mask,
parameters,
*,
- head_size,
+ num_heads,
):
bert_output = ttnn_bert(
input_ids,
token_type_ids,
attention_mask,
parameters,
- head_size=head_size,
+ num_heads=num_heads,
)

qa_outputs = bert_output
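The ttnn variant above mirrors the torch reference: callers now forward only num_heads. One practical upside is that num_heads can be read straight from a Hugging Face config, whereas head_size had to be computed by every caller. A hedged sketch (the model name and usage are illustrative, not taken from this commit):

from transformers import BertConfig

config = BertConfig.from_pretrained("bert-base-uncased")  # assumes the config is cached or downloadable
num_heads = config.num_attention_heads                    # 12 for bert-base-uncased
# Previously each caller did: head_size = config.hidden_size // config.num_attention_heads
# Now head_size is recovered inside the model from the hidden dimension of the activations.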
@@ -8,20 +8,21 @@
def ttnn_optimized_multi_head_attention(
hidden_states,
attention_mask,
- fused_qkv_weight,
- fused_qkv_bias,
+ query_key_value_weight,
+ query_key_value_bias,
self_output_weight,
self_output_bias,
*,
- head_size,
+ num_heads,
num_cores_x=12,
):
- batch_size, *_ = hidden_states.shape
+ batch_size, _, hidden_size = hidden_states.shape
+ head_size = hidden_size // num_heads

- fused_qkv_output = ttnn.linear(
+ query_key_value_output = ttnn.linear(
hidden_states,
- fused_qkv_weight,
- bias=fused_qkv_bias,
+ query_key_value_weight,
+ bias=query_key_value_bias,
memory_config=ttnn.L1_MEMORY_CONFIG,
dtype=ttnn.bfloat8_b,
core_grid=(batch_size, num_cores_x),
@@ -32,11 +33,12 @@ def ttnn_optimized_multi_head_attention(
key,
value,
) = ttnn.nlp.split_query_key_value_and_split_heads(
- fused_qkv_output,
+ query_key_value_output,
memory_config=ttnn.L1_MEMORY_CONFIG,
core_grid=(batch_size, num_cores_x),
+ num_heads=num_heads,
)
- ttnn.deallocate(fused_qkv_output)
+ ttnn.deallocate(query_key_value_output)

attention_scores = ttnn.matmul(
query,
@@ -110,16 +112,16 @@ def ttnn_optimized_bert_encoder(
attention_mask,
parameters,
*,
- head_size,
+ num_heads,
):
multi_head_attention_output = ttnn_optimized_multi_head_attention(
hidden_states,
attention_mask,
- parameters.attention.self.fused_qkv.weight,
- parameters.attention.self.fused_qkv.bias,
+ parameters.attention.self.query_key_value.weight,
+ parameters.attention.self.query_key_value.bias,
parameters.attention.output.dense.weight,
parameters.attention.output.dense.bias,
- head_size=head_size,
+ num_heads=num_heads,
)

multi_head_attention_add_and_layer_norm_output = ttnn.layer_norm(
@@ -159,7 +161,7 @@ def ttnn_optimized_bert(
attention_mask,
parameters,
*,
- head_size,
+ num_heads,
):
word_embeddings = ttnn.embedding(
input_ids,
@@ -193,7 +195,7 @@ def ttnn_optimized_bert(
encoder_input,
attention_mask,
encoder_parameters,
- head_size=head_size,
+ num_heads=num_heads,
)
encoder_output = ttnn.reallocate(encoder_output)
encoder_input = encoder_output
@@ -207,14 +209,14 @@ def ttnn_optimized_bert_for_question_answering(
attention_mask,
parameters,
*,
- head_size,
+ num_heads,
):
bert_output = ttnn_optimized_bert(
input_ids,
token_type_ids,
attention_mask,
parameters,
- head_size=head_size,
+ num_heads=num_heads,
)

qa_outputs = ttnn.linear(
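In the optimized model, the fused QKV projection and its parameters are renamed from fused_qkv to query_key_value, ttnn.nlp.split_query_key_value_and_split_heads now receives num_heads explicitly, and the intermediate projection is freed with ttnn.deallocate once the heads are split. A plain-PyTorch sketch of what such a split does, assuming Q, K and V are concatenated along the last dimension (the ttnn op's exact output layout, e.g. a pre-transposed key, may differ):

import torch


def split_fused_qkv_reference(query_key_value, *, num_heads):
    # query_key_value: [batch_size, sequence_size, 3 * hidden_size]
    batch_size, sequence_size, three_times_hidden_size = query_key_value.shape
    hidden_size = three_times_hidden_size // 3
    head_size = hidden_size // num_heads
    qkv = query_key_value.view(batch_size, sequence_size, 3, num_heads, head_size)
    query, key, value = qkv.unbind(dim=2)          # each [batch, seq, num_heads, head_size]
    return (query.permute(0, 2, 1, 3),             # [batch, num_heads, seq, head_size]
            key.permute(0, 2, 1, 3),
            value.permute(0, 2, 1, 3))


fused = torch.randn(8, 384, 3 * 1024)              # hypothetical fused QKV projection output
q, k, v = split_fused_qkv_reference(fused, num_heads=16)
assert q.shape == k.shape == v.shape == (8, 16, 384, 64)
# PyTorch's allocator reclaims `fused` automatically; on device the diff frees it explicitly
# with ttnn.deallocate to keep L1 memory available for the attention matmuls.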
@@ -59,24 +59,24 @@ def build_alibi_tensor(attention_mask: torch.Tensor, num_heads: int, dtype: torc


# From transformers/models/bloom/modeling_bloom.py
- def split_heads(fused_qkv: torch.Tensor, head_size) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ def split_heads(query_key_value: torch.Tensor, num_heads) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Split the last dimension into (num_heads, head_dim) without making any copies, results share same memory
- storage as `fused_qkv`
+ storage as `query_key_value`
Args:
- fused_qkv (`torch.tensor`, *required*): [batch_size, seq_length, num_heads * 3 * head_dim]
+ query_key_value (`torch.tensor`, *required*): [batch_size, seq_length, num_heads * 3 * head_dim]
Returns:
query: [batch_size, seq_length, num_heads, head_dim] key: [batch_size, seq_length, num_heads, head_dim]
value: [batch_size, seq_length, num_heads, head_dim]
"""
- batch_size, sequence_size, three_times_hidden_size = fused_qkv.shape
+ batch_size, sequence_size, three_times_hidden_size = query_key_value.shape
hidden_size = three_times_hidden_size // 3
- num_heads = hidden_size // head_size
+ head_size = hidden_size // num_heads

- fused_qkv = fused_qkv.view(batch_size, sequence_size, 3, num_heads, head_size)
- return fused_qkv[..., 0, :, :], fused_qkv[..., 1, :, :], fused_qkv[..., 2, :, :]
+ query_key_value = query_key_value.view(batch_size, sequence_size, 3, num_heads, head_size)
+ return query_key_value[..., 0, :, :], query_key_value[..., 1, :, :], query_key_value[..., 2, :, :]


# From transformers/models/bloom/modeling_bloom.py
@@ -92,17 +92,18 @@ def bloom_gelu_forward(x: torch.Tensor) -> torch.Tensor:
return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))


- def create_query_key_value(hidden_states, weight, bias, head_size):
- fused_qkv = hidden_states @ weight
- fused_qkv += bias
- query_layer, key_layer, value_layer = split_heads(fused_qkv, head_size)
+ def create_query_key_value(hidden_states, weight, bias, num_heads):
+ query_key_value = hidden_states @ weight
+ query_key_value += bias
+ query_layer, key_layer, value_layer = split_heads(query_key_value, num_heads)
query_layer = torch.permute(query_layer, (0, 2, 1, 3))
key_layer = torch.permute(key_layer, (0, 2, 3, 1))
value_layer = torch.permute(value_layer, (0, 2, 1, 3))
return query_layer, key_layer, value_layer


- def compute_attention_scores(query_layer, key_layer, alibi, head_size):
+ def compute_attention_scores(query_layer, key_layer, alibi):
+ *_, head_size = query_layer.shape
beta = 1.0
inv_norm_factor = 1.0 / math.sqrt(head_size)
matmul_result = beta * alibi + inv_norm_factor * (query_layer @ key_layer)
@@ -159,12 +160,12 @@ def multi_head_attention(
output_weight,
output_bias,
*,
- head_size,
+ num_heads,
):
query_layer, key_layer, value_layer = create_query_key_value(
- hidden_states, query_key_value_weight, query_key_value_bias, head_size
+ hidden_states, query_key_value_weight, query_key_value_bias, num_heads
)
- attention_scores = compute_attention_scores(query_layer, key_layer, alibi, head_size)
+ attention_scores = compute_attention_scores(query_layer, key_layer, alibi)
attention_probs = compute_attention_probs(attention_scores, causal_mask)
context_layer = compute_context_layer(attention_probs, value_layer)
output_tensor = finalize_output(context_layer, output_weight, output_bias)
@@ -190,7 +191,6 @@ def mlp(
def bloom(input_ids, alibi, causal_mask, parameters, num_heads):
inputs_embeds = F.embedding(input_ids, parameters.transformer.word_embeddings.weight)
hidden_size = inputs_embeds.shape[2]
- head_size = hidden_size // num_heads

hidden_states = F.layer_norm(
inputs_embeds,
@@ -215,7 +215,7 @@ def bloom(input_ids, alibi, causal_mask, parameters, num_heads):
layer_parameters.self_attention.query_key_value.bias,
layer_parameters.self_attention.dense.weight,
layer_parameters.self_attention.dense.bias,
- head_size=head_size,
+ num_heads=num_heads,
)
attention_output += hidden_states

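The Bloom reference drops head_size from compute_attention_scores as well: the helper now reads it off the query tensor, so only num_heads is threaded through the model. A hypothetical illustration of the new signature (shapes and the zero alibi are placeholders, not values from this commit):

import math

import torch


def compute_attention_scores_reference(query_layer, key_layer, alibi):
    *_, head_size = query_layer.shape                # recovered from the tensor, not a parameter
    inv_norm_factor = 1.0 / math.sqrt(head_size)
    return alibi + inv_norm_factor * (query_layer @ key_layer)   # beta is fixed at 1.0 in the diff


query_layer = torch.randn(1, 4, 16, 8)               # batch=1, heads=4, seq=16, head_size=8
key_layer = torch.randn(1, 4, 8, 16)                 # key arrives pre-transposed, as in create_query_key_value
alibi = torch.zeros(1, 4, 16, 16)                    # real ALiBi biases come from build_alibi_tensor
scores = compute_attention_scores_reference(query_layer, key_layer, alibi)
assert scores.shape == (1, 4, 16, 16)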