#4003: updated ttnn_optimized_functional_bloom to work in L1
arakhmati committed Jan 3, 2024
1 parent 762a15b commit 8106e12
Showing 3 changed files with 15 additions and 22 deletions.
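
Note on the change: in the ttnn_optimized_functional_bloom module (first diff below), the module-level default BLOOM_MEMORY_CONFIG now points at L1 (on-chip memory) instead of DRAM, the largest intermediates (the causal-mask add, the softmax, and the residual adds in bloom) are pinned to ttnn.DRAM_MEMORY_CONFIG explicitly, and the ttnn.reallocate calls lose their L1-only guards. As a rough sketch of the allocation pattern the file relies on (only the ttnn calls appear in the diff; the helper and tensor names here are illustrative), an op writes its output into the configured memory space, inputs are freed as soon as possible, and the surviving tensor is reallocated, which appears to be how the code reclaims L1 space after the frees:

import ttnn

BLOOM_MEMORY_CONFIG = ttnn.L1_MEMORY_CONFIG  # new default for intermediates


def residual_add(x, y):
    # The output lands in the configured memory space (L1 with the new default).
    out = ttnn.add(x, y, memory_config=BLOOM_MEMORY_CONFIG)
    # Free the inputs as soon as they are no longer needed...
    ttnn.deallocate(x)
    ttnn.deallocate(y)
    # ...then move the result, presumably so the freed L1 space can be reused
    # without fragmenting (assumed intent, based on how the diff uses reallocate).
    out = ttnn.reallocate(out)
    return out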
@@ -16,7 +16,7 @@
     pad_tensor,
 )
 
-BLOOM_MEMORY_CONFIG = ttnn.DRAM_MEMORY_CONFIG
+BLOOM_MEMORY_CONFIG = ttnn.L1_MEMORY_CONFIG
 BLOOM_DTYPE = ttnn.bfloat8_b
 ASSUME_FUSED_SOFTMAX = False
 
@@ -96,8 +96,6 @@ def compute_attention_scores(query_layer, key_layer, alibi):
     ttnn.deallocate(key_layer)
 
     if ASSUME_FUSED_SOFTMAX:
-        if BLOOM_MEMORY_CONFIG == ttnn.L1_MEMORY_CONFIG:
-            attention_scores = ttnn.reallocate(attention_scores)
         return attention_scores
 
     inv_norm_factor = 1.0 / math.sqrt(head_size)
@@ -107,20 +105,17 @@ def compute_attention_scores(query_layer, key_layer, alibi):
     scaled_attention_scores_plus_alibi = ttnn.add(scaled_attention_scores, alibi, memory_config=BLOOM_MEMORY_CONFIG)
     ttnn.deallocate(scaled_attention_scores)
 
-    if BLOOM_MEMORY_CONFIG == ttnn.L1_MEMORY_CONFIG:
-        scaled_attention_scores_plus_alibi = ttnn.reallocate(scaled_attention_scores_plus_alibi)
-
     return scaled_attention_scores_plus_alibi
 
 
 def compute_attention_probs(attention_scores, causal_mask):
     if ASSUME_FUSED_SOFTMAX:
         attention_weights = attention_scores
     else:
-        attention_weights = ttnn.add(attention_scores, causal_mask, memory_config=BLOOM_MEMORY_CONFIG)
+        attention_weights = ttnn.add(attention_scores, causal_mask, memory_config=ttnn.DRAM_MEMORY_CONFIG)
         ttnn.deallocate(attention_scores)
 
-    attention_probs = ttnn.softmax(attention_weights, dim=-1, memory_config=BLOOM_MEMORY_CONFIG)
+    attention_probs = ttnn.softmax(attention_weights, dim=-1, memory_config=ttnn.DRAM_MEMORY_CONFIG)
     if not ASSUME_FUSED_SOFTMAX:
         ttnn.deallocate(attention_weights)
 
@@ -168,6 +163,7 @@ def multi_head_attention(
     query_layer, key_layer, value_layer = create_query_key_value(
         hidden_states, query_key_value_weight, query_key_value_bias, num_heads=num_heads
     )
+    value_layer = ttnn.reallocate(value_layer)
 
     attention_scores = compute_attention_scores(query_layer, key_layer, alibi)
     attention_probs = compute_attention_probs(attention_scores, causal_mask)
@@ -217,15 +213,14 @@ def bloom(
         layout=ttnn.TILE_LAYOUT,
     )
 
+    # TODO(arakhmati): put hidden_states in L1
     hidden_states = ttnn.layer_norm(
         inputs_embeds,
         weight=parameters.transformer.word_embeddings_layernorm.weight,
         bias=parameters.transformer.word_embeddings_layernorm.bias,
-        memory_config=BLOOM_MEMORY_CONFIG,
+        memory_config=ttnn.DRAM_MEMORY_CONFIG,
     )
     ttnn.deallocate(inputs_embeds)
-    if BLOOM_MEMORY_CONFIG == ttnn.L1_MEMORY_CONFIG:
-        hidden_states = ttnn.reallocate(hidden_states)
 
     for layer_parameters in parameters.transformer.h:
         normalized_hidden_states = ttnn.layer_norm(
@@ -247,7 +242,8 @@ def bloom(
         )
         ttnn.deallocate(normalized_hidden_states)
 
-        attention_output = ttnn.add(attention_output, hidden_states, memory_config=BLOOM_MEMORY_CONFIG)
+        # TODO(arakhmati): put attention_output in L1
+        attention_output = ttnn.add(attention_output, hidden_states, memory_config=ttnn.DRAM_MEMORY_CONFIG)
         ttnn.deallocate(hidden_states)
 
         normalized_attention_output = ttnn.layer_norm(
@@ -266,13 +262,12 @@
         )
         ttnn.deallocate(normalized_attention_output)
 
-        mlp_output = ttnn.add(mlp_output, attention_output, memory_config=BLOOM_MEMORY_CONFIG)
+        # TODO(arakhmati): put mlp_output in L1
+        mlp_output = ttnn.add(mlp_output, attention_output, memory_config=ttnn.DRAM_MEMORY_CONFIG)
         ttnn.deallocate(attention_output)
 
         hidden_states = mlp_output
-
-        if BLOOM_MEMORY_CONFIG == ttnn.L1_MEMORY_CONFIG:
-            hidden_states = ttnn.reallocate(hidden_states)
+        hidden_states = ttnn.reallocate(hidden_states)
 
     hidden_states = ttnn.layer_norm(
         hidden_states,
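
Even with the L1 default, compute_attention_probs above pins the causal-mask add and the softmax output to DRAM, and the residual adds in bloom carry TODOs to move into L1 later. Those tensors scale with the square of the sequence length per attention head, which makes them the least likely to fit on-chip. A back-of-the-envelope helper (purely illustrative, not part of the repo) shows the scaling:

def attention_scores_bytes(batch_size, num_heads, sequence_length, bytes_per_element=2):
    # One score per (query position, key position) pair, per head, per batch entry.
    # bytes_per_element=2 corresponds to a 16-bit format; the module also defines
    # BLOOM_DTYPE = ttnn.bfloat8_b, which is denser than this estimate.
    return batch_size * num_heads * sequence_length * sequence_length * bytes_per_element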
@@ -64,11 +64,11 @@ def test_bloom_for_question_answering(device, use_program_cache, ttnn_model, bat
     tt_end_logits = tt_output[:1, :num_tokens, 1]
 
     if ttnn_model == ttnn_functional_bloom:
-        assert_with_pcc(torch_start_logits, tt_start_logits, 0.939)
-        assert_with_pcc(torch_end_logits, tt_end_logits, 0.911)
+        assert_with_pcc(torch_start_logits, tt_start_logits, 0.96677)
+        assert_with_pcc(torch_end_logits, tt_end_logits, 0.95177)
     elif ttnn_model == ttnn_optimized_functional_bloom:
-        assert_with_pcc(torch_start_logits, tt_start_logits, 0.88)
-        assert_with_pcc(torch_end_logits, tt_end_logits, 0.88)
+        assert_with_pcc(torch_start_logits, tt_start_logits, 0.93999)
+        assert_with_pcc(torch_end_logits, tt_end_logits, 0.88868)
     else:
         raise RecursionError("Invalid ttnn_model")
 
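
assert_with_pcc compares the Torch reference output and the ttnn output by Pearson correlation coefficient (PCC) against a minimum threshold; the expected values above appear to be adjusted to the newly observed correlations. A rough standalone equivalent, with hypothetical helper names rather than the repo's implementation:

import torch


def pearson_correlation(expected: torch.Tensor, actual: torch.Tensor) -> float:
    # Pearson correlation coefficient between the two flattened tensors.
    stacked = torch.stack([expected.flatten().float(), actual.flatten().float()])
    return torch.corrcoef(stacked)[0, 1].item()


def check_with_pcc(expected: torch.Tensor, actual: torch.Tensor, threshold: float) -> None:
    pcc = pearson_correlation(expected, actual)
    assert pcc >= threshold, f"PCC {pcc:.5f} is below the required {threshold}"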
tests/ttnn/integration_tests/bloom/test_performance.py (2 changes: 0 additions & 2 deletions)
@@ -68,7 +68,6 @@ def test_performance_of_bloom_for_question_answering(
     )
 
     # TODO: don't modify the config globally. Pass it into the functions instead
-    ttnn_optimized_functional_bloom.BLOOM_MEMORY_CONFIG = ttnn.L1_MEMORY_CONFIG
     ttnn_optimized_functional_bloom.ASSUME_FUSED_SOFTMAX = True
 
     durations = []
@@ -100,5 +99,4 @@ def test_performance_of_bloom_for_question_answering(
     logger.info(f"Samples per second: {1 / inference_time * batch_size}")
 
     # TODO: don't modify the config globally. Pass it into the functions instead
-    ttnn_optimized_functional_bloom.BLOOM_MEMORY_CONFIG = ttnn.DRAM_MEMORY_CONFIG
     ttnn_optimized_functional_bloom.ASSUME_FUSED_SOFTMAX = False
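
Both TODOs in this file point at the same issue: the test mutates module-level globals and restores them by hand at the end, which leaves the module in the wrong state if the test fails partway through. Until the config is passed into the functions, one way to keep the override scoped to a test is pytest's monkeypatch fixture, sketched below; the test name is shortened and the module is assumed to be imported the same way the existing test already imports it.

def test_performance_with_fused_softmax(monkeypatch):
    # monkeypatch restores the original attribute value automatically,
    # whether the test passes or fails.
    monkeypatch.setattr(ttnn_optimized_functional_bloom, "ASSUME_FUSED_SOFTMAX", True)
    # ... run the model and collect durations exactly as the existing test does ...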
