diff --git a/models/demos/metal_BERT_large_11/demo/demo.py b/models/demos/metal_BERT_large_11/demo/demo.py
index 8cb85ed4fad..43f44aadcfc 100644
--- a/models/demos/metal_BERT_large_11/demo/demo.py
+++ b/models/demos/metal_BERT_large_11/demo/demo.py
@@ -378,11 +378,11 @@ def test_demo(
 
     return run_bert_question_and_answering_inference(
         model_version="phiyodr/bert-large-finetuned-squad2",
-        batch=8,
+        batch=12,
         seq_len=384,
         return_attention_mask=True,
         return_token_type_ids=True,
-        model_config=get_model_config("MIXED_PRECISION_BATCH8"),
+        model_config=get_model_config("BFLOAT8_B-SHARDED_BATCH12"),
         tt_cache_path=get_tt_cache_path("phiyodr/bert-large-finetuned-squad2"),
         NUM_RUNS=NUM_RUNS,
         input_path=input_path,
diff --git a/models/demos/metal_BERT_large_11/demo/input_data.json b/models/demos/metal_BERT_large_11/demo/input_data.json
index 98877ac2488..950b8d36323 100644
--- a/models/demos/metal_BERT_large_11/demo/input_data.json
+++ b/models/demos/metal_BERT_large_11/demo/input_data.json
@@ -30,5 +30,21 @@
     {
         "context" : "The largest single sensory feature is the aboral organ (at the opposite end from the mouth). Its main component is a statocyst, a balance sensor consisting of a statolith, a solid particle supported on four bundles of cilia, called \"balancers\", that sense its orientation. The statocyst is protected by a transparent dome made of long, immobile cilia. A ctenophore does not automatically try to keep the statolith resting equally on all the balancers. Instead its response is determined by the animal's \"mood\", in other words the overall state of the nervous system. For example, if a ctenophore with trailing tentacles captures prey, it will often put some comb rows into reverse, spinning the mouth towards the prey.",
         "question" : "What is the main component of the aboral organ?"
+    },
+    {
+        "context": "Mark Rothko was a Latvian-born American abstract painter. He is best known for his color field paintings that depicted irregular and painterly rectangular regions of color, which he produced from 1949 to 1970. Although Rothko did not personally subscribe to any one school, he is associated with the American Abstract Expressionist movement of modern art. Originally emigrating to Portland, Oregon, from Russian Empire (Latvia) with his family, Rothko later moved to New York City where his youthful period of artistic production dealt primarily with urban scenery.",
+        "question": "what is Rothko best known for?"
+    },
+    {
+        "context": "Malignant narcissism is a psychological syndrome that could include aspects of narcissistic personality disorder (NPD) alongside a mix of antisocial, paranoid and sadistic personality disorder traits. The importance of malignant narcissism and of projection as a defense mechanism has been confirmed in paranoia, as well as the patient's vulnerability to malignant narcissistic regression. A person with malignant narcissism exhibits paranoia in addition to the symptoms of a Narcissistic Personality Disorder. Because a malignant narcissist's personality cannot tolerate any criticism, being mocked typically causes paranoia.",
+        "question": "What symptoms a malignant narcissist might exhibit in addition to the symptoms of a NPD patient?"
+    },
+    {
+        "context": "The 14 July Revolution, also known as the 1958 Iraqi military coup, was a coup d'état that took place on 14 July 1958 in Iraq which resulted in the toppling of King Faisal II and the overthrow of the Hashemite-led Kingdom of Iraq. The Iraqi Republic established in its wake ended the Hashemite Arab Federation between Iraq and Jordan that had been established just six months earlier. In July 1958, units of the Royal Iraqi Army were dispatched to Jordan in support of King Hussein. A group of Iraqi Free Officers, led by Brigadier Abd al-Karim Qasim and Colonel Abdul Salam Arif, took advantage of the opportunity and instead marched on Baghdad. On 14 July, revolutionary forces seized control of the capital and proclaimed a new republic, headed by a Revolutionary Council.",
+        "question": "When was the Hashemite Arab Federation formed?"
+    },
+    {
+        "context": "The Tasmanian devil is a carnivorous marsupial of the family Dasyuridae. It was formerly present across mainland Australia, but became extinct there around 3,500 years ago. The size of a small dog, the Tasmanian devil became the largest carnivorous marsupial in the world following the extinction of the thylacine in 1936. It is related to quolls, and distantly related to the thylacine. It is characterised by its stocky and muscular build, black fur, pungent odour, extremely loud and disturbing screech, keen sense of smell, and ferocity when feeding. The Tasmanian devil's large head and neck allow it to generate among the strongest bites per unit body mass of any extant predatory land mammal. It hunts prey and scavenges on carrion.",
+        "question": "What allows Tasmanian devil to generate strong bites?"
     }
 ]
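With batch=12 in the demo, input_data.json grows from eight to twelve QA pairs so that every slot in a batch carries its own question; the test_demo.py change below asserts answers for indices 0 through 11 to match. A quick local sanity check (a sketch, assuming it is run from the repo root):

```python
import json

# Sanity check: batch=12 needs twelve QA pairs in the demo input.
# The path comes from the diff above; run from the repo root.
with open("models/demos/metal_BERT_large_11/demo/input_data.json") as f:
    samples = json.load(f)

assert len(samples) == 12, f"expected 12 QA pairs, found {len(samples)}"
assert all({"context", "question"} <= entry.keys() for entry in samples)
```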
diff --git a/models/demos/metal_BERT_large_11/tests/test_demo.py b/models/demos/metal_BERT_large_11/tests/test_demo.py
index 6977c810ad2..077fae5ca51 100644
--- a/models/demos/metal_BERT_large_11/tests/test_demo.py
+++ b/models/demos/metal_BERT_large_11/tests/test_demo.py
@@ -20,13 +20,17 @@ def test_demo(input_path, model_location_generator, device, use_program_cache):
 
     expected_answers = {
         0: "scientific archaeology",
-        1: "Richard I of Normandy",
+        1: "Richard I",
         2: "males",
-        3: "The Huguenots adapted quickly and often married outside their immediate French communities,",
+        3: "married outside their immediate French communities,",
         4: "biostratigraphers",
         5: "chemotaxis,",
         6: "1992,",
         7: "statocyst,",
+        8: "color field paintings",
+        9: "paranoia",
+        10: "six months earlier.",
+        11: "large head and neck",
     }
     NUM_RUNS = 1000
     measurements, answers = demo_json(input_path, NUM_RUNS, model_location_generator, device, use_program_cache)
diff --git a/models/demos/metal_BERT_large_11/tests/test_perf_bert11.py b/models/demos/metal_BERT_large_11/tests/test_perf_bert11.py
index 77d6cc898dc..a2fb59ddcb8 100644
--- a/models/demos/metal_BERT_large_11/tests/test_perf_bert11.py
+++ b/models/demos/metal_BERT_large_11/tests/test_perf_bert11.py
@@ -150,7 +150,7 @@ def run_perf_bert11(
 @pytest.mark.models_performance_virtual_machine
 @pytest.mark.parametrize(
     "expected_inference_time, expected_compile_time, inference_iterations",
-    ([0.07, 14.5, 10],),
+    ([0.05, 14.5, 10],),
 )
 def test_perf_virtual_machine(
     use_program_cache,
@@ -171,7 +171,7 @@ def test_perf_virtual_machine(
 @pytest.mark.models_performance_bare_metal
 @pytest.mark.parametrize(
     "expected_inference_time, expected_compile_time, inference_iterations",
-    ([0.04, 5, 10],),
+    ([0.0375, 10, 10],),
 )
 def test_perf_bare_metal(
     use_program_cache,
diff --git a/models/demos/metal_BERT_large_11/tests/test_perf_device_bert.py b/models/demos/metal_BERT_large_11/tests/test_perf_device_bert.py
index d24ee2451e2..635eddb7983 100644
--- a/models/demos/metal_BERT_large_11/tests/test_perf_device_bert.py
+++ b/models/demos/metal_BERT_large_11/tests/test_perf_device_bert.py
@@ -19,8 +19,8 @@
     "batch_size, test, expected_perf",
     [
         # [9, "BERT_LARGE-batch_9-MIXED_PRECISION_BATCH9", 70],
-        [8, "BERT_LARGE-batch_8-MIXED_PRECISION_BATCH8", 160],
-        [12, "BERT_LARGE-batch_12-BFLOAT8_B-SHARDED_BATCH12", 160],
+        [8, "BERT_LARGE-batch_8-MIXED_PRECISION_BATCH8", 165],
+        [12, "BERT_LARGE-batch_12-BFLOAT8_B-SHARDED_BATCH12", 390],
     ],
 )
 def test_perf_device_bare_metal(batch_size, test, expected_perf):
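The device-perf bar for the batch-12 sharded config rises from 160 to 390, and the bare-metal end-to-end target tightens from 0.04 s to 0.0375 s (while allowing a 10 s compile). If expected_perf is a sentences-per-second throughput figure (an assumption; the hunk does not show the unit), the implied latencies are easy to read off:

```python
# Back-of-envelope check, assuming expected_perf is sentences/second
# (the unit is not shown in the hunk above).
batch, expected_perf = 12, 390
print(f"{batch / expected_perf * 1e3:.1f} ms per batch-12 run")  # ~30.8 ms
print(f"{1e3 / expected_perf:.2f} ms per sentence")              # ~2.56 ms
```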
diff --git a/models/demos/metal_BERT_large_11/tt/bert_encoder.py b/models/demos/metal_BERT_large_11/tt/bert_encoder.py
index 80c5cf6c955..7cf081c654b 100644
--- a/models/demos/metal_BERT_large_11/tt/bert_encoder.py
+++ b/models/demos/metal_BERT_large_11/tt/bert_encoder.py
@@ -139,66 +139,59 @@ def __init__(self, config, encoder_idx, state_dict, device, model_config, tt_cache_path):
         self.layer_norm_eps = config.layer_norm_eps
 
         if "OP7_SELFOUT_CONFIG" in model_config:
-            self.selfout_matmul = partial(
-                tt_lib.operations.primary.matmul,
-                program_config=model_config["OP7_SELFOUT_CONFIG"],
-                output_mem_config=model_config["OP7_SELFOUT_OUTPUT_MEMCFG"],
-                output_dtype=model_config["OP7_SELFOUT_OUTPUT_DTYPE"],
-            )
-        else:
-            self.selfout_matmul = partial(
-                tt_lib.tensor.bert_large_selfout_matmul,
-                output_mem_config=model_config["OP7_SELFOUT_OUTPUT_MEMCFG"],
-                output_dtype=model_config["OP7_SELFOUT_OUTPUT_DTYPE"],
-            )
-        if "OP8_LAYERNORM_CONFIG" in model_config:
-            self.mha_layernorm = partial(
-                tt_lib.operations.primary.add_layernorm,
-                program_config=model_config["OP8_LAYERNORM_CONFIG"],
-                output_mem_config=model_config["OP8_LAYERNORM_OUTPUT_MEMCFG"],
-            )
-        else:
-            self.mha_layernorm = partial(
-                tt_lib.operations.primary.add_layernorm,
-                output_mem_config=model_config["OP8_LAYERNORM_OUTPUT_MEMCFG"],
-            )
-        if "OP11_LAYERNORM_CONFIG" in model_config:
-            self.ffn_layernorm = partial(
-                tt_lib.operations.primary.add_layernorm,
-                program_config=model_config["OP11_LAYERNORM_CONFIG"],
-                output_mem_config=model_config["OP11_LAYERNORM_OUTPUT_MEMCFG"],
-            )
+
+            def op7_mm_plus_bias(mha_res, attention_output_weight, attention_output_bias):
+                mha_out = tt_lib.operations.primary.matmul(
+                    mha_res,
+                    attention_output_weight,
+                    bias=attention_output_bias,
+                    program_config=model_config["OP7_SELFOUT_CONFIG"],
+                    output_mem_config=model_config["OP7_SELFOUT_OUTPUT_MEMCFG"],
+                    output_dtype=model_config["OP7_SELFOUT_OUTPUT_DTYPE"],
+                )
+                return mha_out
+
         else:
-            self.ffn_layernorm = partial(
-                tt_lib.operations.primary.add_layernorm,
-                output_mem_config=model_config["OP11_LAYERNORM_OUTPUT_MEMCFG"],
-            )
 
-    def op7_mm_plus_bias(self, mha_res, attention_output_weight, attention_output_bias):
-        mha_out = self.selfout_matmul(
-            mha_res,
-            attention_output_weight,
-            bias=attention_output_bias,
+            def op7_mm_plus_bias(mha_res, attention_output_weight, attention_output_bias):
+                mha_out = tt_lib.tensor.bert_large_selfout_matmul(
+                    mha_res,
+                    attention_output_weight,
+                    bias=attention_output_bias,
+                    output_mem_config=model_config["OP7_SELFOUT_OUTPUT_MEMCFG"],
+                    output_dtype=model_config["OP7_SELFOUT_OUTPUT_DTYPE"],
+                )
+                return mha_out
+
+        self.op7_mm_plus_bias = op7_mm_plus_bias
+        self.mha_ln_program_config = model_config.get(
+            "OP8_LAYERNORM_CONFIG", tt_lib.operations.primary.LayerNormDefaultProgramConfig()
+        )
+        self.ffn_ln_program_config = model_config.get(
+            "OP11_LAYERNORM_CONFIG", tt_lib.operations.primary.LayerNormDefaultProgramConfig()
         )
-        return mha_out
 
     def op8_add_layernorm(self, activation, mha_out):
-        mha_out_add_and_norm = self.mha_layernorm(
+        mha_out_add_and_norm = tt_lib.operations.primary.add_layernorm(
             activation,
             mha_out,
             self.layer_norm_eps,
             self.mha_gamma,
             self.mha_beta,
+            program_config=self.mha_ln_program_config,
+            output_mem_config=self.model_config["OP8_LAYERNORM_OUTPUT_MEMCFG"],
         )
         return mha_out_add_and_norm
 
     def op11_add_layernorm(self, mha_out_add_and_norm, ffn_out):
-        ffn_out_add_and_norm = self.ffn_layernorm(
+        ffn_out_add_and_norm = tt_lib.operations.primary.add_layernorm(
             mha_out_add_and_norm,
             ffn_out,
             self.layer_norm_eps,
             self.ffn_gamma,
             self.ffn_beta,
+            program_config=self.ffn_ln_program_config,
+            output_mem_config=self.model_config["OP11_LAYERNORM_OUTPUT_MEMCFG"],
         )
         return ffn_out_add_and_norm
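The bert_encoder.py refactor (mirrored in ffn.py and mha.py below) replaces functools.partial-bound ops with plain closures chosen once at construction time, and collapses the layernorm if/else pairs into a single model_config.get(...) with a default program config. A minimal sketch of the dispatch pattern, using hypothetical stand-in ops (sharded_matmul and fallback_matmul are placeholders, not the tt_lib API):

```python
# Minimal sketch of the refactor pattern. The config key is tested once at
# construction; the returned closure runs on the hot path with no per-call
# branching, replacing the earlier functools.partial bindings.
def make_op7(model_config, sharded_matmul, fallback_matmul):
    if "OP7_SELFOUT_CONFIG" in model_config:

        def op7_mm_plus_bias(x, weight, bias):
            # Sharded path: explicit program config from the model config.
            return sharded_matmul(
                x, weight, bias=bias, program_config=model_config["OP7_SELFOUT_CONFIG"]
            )

    else:

        def op7_mm_plus_bias(x, weight, bias):
            # Fallback path: dedicated BERT-large op with its own defaults.
            return fallback_matmul(x, weight, bias=bias)

    return op7_mm_plus_bias
```

The layernorms avoid even this branch: model_config.get("OP8_LAYERNORM_CONFIG", LayerNormDefaultProgramConfig()) picks the default program config when the sharded one is absent, so a single call site serves both configurations.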
diff --git a/models/demos/metal_BERT_large_11/tt/ffn.py b/models/demos/metal_BERT_large_11/tt/ffn.py
index 599780fb103..b25435fcc3f 100644
--- a/models/demos/metal_BERT_large_11/tt/ffn.py
+++ b/models/demos/metal_BERT_large_11/tt/ffn.py
@@ -4,7 +4,6 @@
 
 import torch
 
-from functools import partial
 import tt_lib
 from tt_lib.utils import pad_weight
 
@@ -25,51 +24,54 @@ def feed_forward(
     # ff1_weighta = [1, 1, 1024, 4096]
     # output = [1, 9, 384, 4096]
     if "OP9_FF1_MM_CONFIG" in model_config:
-        ff1_matmul = partial(
-            tt_lib.operations.primary.matmul,
-            program_config=model_config["OP9_FF1_MM_CONFIG"],
-            output_mem_config=model_config["OP9_FF1_MM_OUTPUT_MEMCFG"],
-            output_dtype=model_config["OP9_FF1_MM_OUTPUT_DTYPE"],
-        )
+
+        def op9_MM_bias_gelu(activation, ff1_weighta, ff1_biasa):
+            output_plus_bias_act = tt_lib.operations.primary.matmul(
+                activation,
+                ff1_weighta,
+                bias=ff1_biasa,
+                program_config=model_config["OP9_FF1_MM_CONFIG"],
+                output_mem_config=model_config["OP9_FF1_MM_OUTPUT_MEMCFG"],
+                output_dtype=model_config["OP9_FF1_MM_OUTPUT_DTYPE"],
+            )
+            return output_plus_bias_act
+
     else:
-        ff1_matmul = partial(
-            tt_lib.tensor.bert_large_ff1_matmul,
-            fused_activation=(tt_lib.tensor.FusibleActivation.GELU, True),
-            output_mem_config=model_config["OP9_FF1_MM_OUTPUT_MEMCFG"],
-            output_dtype=model_config["OP9_FF1_MM_OUTPUT_DTYPE"],
-        )
+
+        def op9_MM_bias_gelu(activation, ff1_weighta, ff1_biasa):
+            output_plus_bias_act = tt_lib.tensor.bert_large_ff1_matmul(
+                activation,
+                ff1_weighta,
+                bias=ff1_biasa,
+                fused_activation=(tt_lib.tensor.FusibleActivation.GELU, True),
+                output_mem_config=model_config["OP9_FF1_MM_OUTPUT_MEMCFG"],
+                output_dtype=model_config["OP9_FF1_MM_OUTPUT_DTYPE"],
+            )
+            return output_plus_bias_act
+
     if "OP10_FF2_MM_CONFIG" in model_config:
-        ff2_matmul = partial(
-            tt_lib.operations.primary.matmul,
-            program_config=model_config["OP10_FF2_MM_CONFIG"],
-            output_mem_config=model_config["OP10_FF2_MM_OUTPUT_MEMCFG"],
-            output_dtype=model_config["OP10_FF2_MM_OUTPUT_DTYPE"],
-        )
+
+        def op10_MM_bias(activation, ff2_weighta, ff2_biasa):
+            output_plus_bias = tt_lib.operations.primary.matmul(
+                activation,
+                ff2_weighta,
+                bias=ff2_biasa,
+                program_config=model_config["OP10_FF2_MM_CONFIG"],
+                output_mem_config=model_config["OP10_FF2_MM_OUTPUT_MEMCFG"],
+                output_dtype=model_config["OP10_FF2_MM_OUTPUT_DTYPE"],
+            )
+            return output_plus_bias
+
     else:
-        ff2_matmul = partial(
-            tt_lib.tensor.bert_large_ff2_matmul,
-            output_mem_config=model_config["OP10_FF2_MM_OUTPUT_MEMCFG"],
-            output_dtype=model_config["OP10_FF2_MM_OUTPUT_DTYPE"],
-        )
 
-    def op9_MM_bias_gelu(activation, ff1_weighta, ff1_biasa):
-        output_plus_bias_act = ff1_matmul(
-            activation,
-            ff1_weighta,
-            bias=ff1_biasa,
-        )
-        return output_plus_bias_act
-
-    # activation = [1, 9, 384, 4096]
-    # ff2_weighta = [1, 1, 4096, 1024]
-    # output = [1, 9, 384, 1024]
-    def op10_MM_bias(activation, ff2_weighta, ff2_biasa):
-        output_plus_bias = ff2_matmul(
-            activation,
-            ff2_weighta,
-            bias=ff2_biasa,
-        )
-        return output_plus_bias
+        def op10_MM_bias(activation, ff2_weighta, ff2_biasa):
+            output_plus_bias = tt_lib.tensor.bert_large_ff2_matmul(
+                activation,
+                ff2_weighta,
+                bias=ff2_biasa,
+                output_mem_config=model_config["OP10_FF2_MM_OUTPUT_MEMCFG"],
+                output_dtype=model_config["OP10_FF2_MM_OUTPUT_DTYPE"],
+            )
+            return output_plus_bias
 
     def feed_forward_(activation: tt_lib.tensor.Tensor) -> tt_lib.tensor.Tensor:
         ff1_output_plus_bias_act = op9_MM_bias_gelu(activation, ff1_weighta, ff1_biasa)
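Both FF1 variants fold the bias add and the GELU into the matmul itself: the fallback path requests it via fused_activation=(FusibleActivation.GELU, True), and on the sharded path the activation is presumably carried inside OP9_FF1_MM_CONFIG. As a reference for what the fused ops compute (an unfused sketch, not the device implementation):

```python
import torch

# Unfused reference: op9 computes gelu(x @ W1 + b1), op10 computes y @ W2 + b2.
# Shapes per the comments in ffn.py: [1, 9, 384, 1024] @ [1, 1, 1024, 4096].
def op9_reference(activation, ff1_weight, ff1_bias):
    return torch.nn.functional.gelu(activation @ ff1_weight + ff1_bias)

def op10_reference(ff1_out, ff2_weight, ff2_bias):
    return ff1_out @ ff2_weight + ff2_bias
```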
diff --git a/models/demos/metal_BERT_large_11/tt/mha.py b/models/demos/metal_BERT_large_11/tt/mha.py
index 2c50a3f6c15..3f8e95465da 100644
--- a/models/demos/metal_BERT_large_11/tt/mha.py
+++ b/models/demos/metal_BERT_large_11/tt/mha.py
@@ -10,7 +10,6 @@
 import tt_lib
 from tt_lib.utils import pad_weight
 from models.utility_functions import torch2tt_tensor
-from functools import partial
 
 
 def mha(qkv_weight, qkv_bias, hidden_dim, num_heads, device, model_config):
@@ -22,83 +21,70 @@ def mha(qkv_weight, qkv_bias, hidden_dim, num_heads, device, model_config):
     reserve_split_heads_shape = model_config.get("RESERVE_SPLIT_HEADS_SHAPE", None)
 
     if "OP1_FUSED_QKV_MM_CONFIG" in model_config:
-        qkv_matmul = partial(
-            tt_lib.operations.primary.matmul,
-            program_config=model_config["OP1_FUSED_QKV_MM_CONFIG"],
-            output_mem_config=model_config["OP1_FUSED_QKV_MM_OUTPUT_MEMCFG"],
-            output_dtype=model_config["OP1_FUSED_QKV_MM_OUTPUT_DTYPE"],
-        )
-    else:
-        qkv_matmul = partial(
-            tt_lib.tensor.bert_large_fused_qkv_matmul,
-            output_mem_config=model_config["OP1_FUSED_QKV_MM_OUTPUT_MEMCFG"],
-            output_dtype=model_config["OP1_FUSED_QKV_MM_OUTPUT_DTYPE"],
-        )
-    if "OP3_PRE_SOFTMAX_BMM_CONFIG" in model_config:
-        pre_softmax_bmm = partial(
-            tt_lib.operations.primary.matmul,
-            program_config=model_config["OP3_PRE_SOFTMAX_BMM_CONFIG"],
-            output_mem_config=model_config["OP3_PRE_SOFTMAX_BMM_OUTPUT_MEMCFG"],
-            output_dtype=model_config["OP3_PRE_SOFTMAX_BMM_OUTPUT_DTYPE"],
-        )
-    else:
-        pre_softmax_bmm = partial(
-            tt_lib.tensor.bert_large_pre_softmax_bmm,
-            output_mem_config=model_config["OP3_PRE_SOFTMAX_BMM_OUTPUT_MEMCFG"],
-            output_dtype=model_config["OP3_PRE_SOFTMAX_BMM_OUTPUT_DTYPE"],
-        )
-    if "OP5_POST_SOFTMAX_BMM_CONFIG" in model_config:
-        post_softmax_bmm = partial(
-            tt_lib.operations.primary.matmul,
-            program_config=model_config["OP5_POST_SOFTMAX_BMM_CONFIG"],
-            output_mem_config=model_config["OP5_POST_SOFTMAX_BMM_OUTPUT_MEMCFG"],
-            output_dtype=model_config["OP5_POST_SOFTMAX_BMM_OUTPUT_DTYPE"],
-        )
-    else:
-        post_softmax_bmm = partial(
-            tt_lib.tensor.bert_large_post_softmax_bmm,
-            output_mem_config=model_config["OP5_POST_SOFTMAX_BMM_OUTPUT_MEMCFG"],
-            output_dtype=model_config["OP5_POST_SOFTMAX_BMM_OUTPUT_DTYPE"],
-        )
-    if "OP4_SOFTMAX_CONFIG" in model_config:
-        softmax = partial(
-            tt_lib.operations.primary.transformers.scale_mask_softmax_in_place,
-            program_config=model_config["OP4_SOFTMAX_CONFIG"],
-        )
+
+        def op1_qkv_fused(activation, qkv_weight, qkv_bias):
+            qkv = tt_lib.operations.primary.matmul(
+                activation,
+                qkv_weight,
+                bias=qkv_bias,
+                program_config=model_config["OP1_FUSED_QKV_MM_CONFIG"],
+                output_mem_config=model_config["OP1_FUSED_QKV_MM_OUTPUT_MEMCFG"],
+                output_dtype=model_config["OP1_FUSED_QKV_MM_OUTPUT_DTYPE"],
+            )
+            return qkv
+
     else:
-        softmax = tt_lib.operations.primary.transformers.scale_mask_softmax_in_place
-    split_fused_qkv_and_split_heads = partial(
-        tt_lib.operations.primary.transformers.split_fused_qkv_and_split_heads,
-        compute_with_storage_grid_size=model_config.get("GRID_SIZE", device.compute_with_storage_grid_size()),
-        output_mem_config=model_config["OP2_SPLIT_QKV_HEADS_OUTPUT_MEMCFG"],
-    )
+
+        def op1_qkv_fused(activation, qkv_weight, qkv_bias):
+            qkv = tt_lib.tensor.bert_large_fused_qkv_matmul(
+                activation,
+                qkv_weight,
+                bias=qkv_bias,
+                output_mem_config=model_config["OP1_FUSED_QKV_MM_OUTPUT_MEMCFG"],
+                output_dtype=model_config["OP1_FUSED_QKV_MM_OUTPUT_DTYPE"],
+            )
+            return qkv
 
-    def op1_qkv_fused(activation, qkv_weight, qkv_bias):
-        qkv = qkv_matmul(
-            activation,
-            qkv_weight,
-            bias=qkv_bias,
-        )
-        return qkv
+    grid_size = model_config.get("GRID_SIZE", device.compute_with_storage_grid_size())
 
     def op2_create_qkv_heads(qkv):
         (
             q_heads,
             kt_heads,
             v_heads,
-        ) = split_fused_qkv_and_split_heads(
+        ) = tt_lib.operations.primary.transformers.split_fused_qkv_and_split_heads(
             qkv,
+            compute_with_storage_grid_size=grid_size,
+            output_mem_config=model_config["OP2_SPLIT_QKV_HEADS_OUTPUT_MEMCFG"],
         )
         return q_heads, kt_heads, v_heads
 
-    def op3_bmm(Q_heads, K_T_heads):
-        qkt = pre_softmax_bmm(
-            Q_heads,
-            K_T_heads,
-        )
-        return qkt
+    if "OP3_PRE_SOFTMAX_BMM_CONFIG" in model_config:
+
+        def op3_bmm(Q_heads, K_T_heads):
+            qkt = tt_lib.operations.primary.matmul(
+                Q_heads,
+                K_T_heads,
+                program_config=model_config["OP3_PRE_SOFTMAX_BMM_CONFIG"],
+                output_mem_config=model_config["OP3_PRE_SOFTMAX_BMM_OUTPUT_MEMCFG"],
+                output_dtype=model_config["OP3_PRE_SOFTMAX_BMM_OUTPUT_DTYPE"],
+            )
+            return qkt
+
+    else:
+
+        def op3_bmm(Q_heads, K_T_heads):
+            qkt = tt_lib.tensor.bert_large_pre_softmax_bmm(
+                Q_heads,
+                K_T_heads,
+                output_mem_config=model_config["OP3_PRE_SOFTMAX_BMM_OUTPUT_MEMCFG"],
+                output_dtype=model_config["OP3_PRE_SOFTMAX_BMM_OUTPUT_DTYPE"],
+            )
+            return qkt
+
+    softmax_program_config = model_config.get(
+        "OP4_SOFTMAX_CONFIG", tt_lib.operations.primary.transformers.SoftmaxDefaultProgramConfig()
+    )
 
     def op4_scale_mask_softmax(qkt, attention_mask):
         # Attention scores computation
@@ -107,7 +93,9 @@ def op4_scale_mask_softmax(qkt, attention_mask):
         # No-op reshapes are handled within pre-softmax (op 7) and post-softmax bmms (op 9)
         shape = qkt.shape()
         qkt = qkt.reshape(shape[0], 1, shape[1] * shape[2], shape[3])
-        attention_scores = softmax(qkt, freciprocal_of_sqrt_hidden_dim, attention_mask)
+        attention_scores = tt_lib.operations.primary.transformers.scale_mask_softmax_in_place(
+            qkt, freciprocal_of_sqrt_hidden_dim, attention_mask, program_config=softmax_program_config
+        )
         attention_scores = attention_scores.reshape(shape)
         return attention_scores
 
@@ -120,6 +108,31 @@ def op5_bmm(attention_scores, V_heads):
 
         return weighted_activation
 
+    if "OP5_POST_SOFTMAX_BMM_CONFIG" in model_config:
+
+        def op5_bmm(attention_scores, V_heads):
+            weighted_activation = tt_lib.operations.primary.matmul(
+                attention_scores,
+                V_heads,
+                program_config=model_config["OP5_POST_SOFTMAX_BMM_CONFIG"],
+                output_mem_config=model_config["OP5_POST_SOFTMAX_BMM_OUTPUT_MEMCFG"],
+                output_dtype=model_config["OP5_POST_SOFTMAX_BMM_OUTPUT_DTYPE"],
+            )
+
+            return weighted_activation
+
+    else:
+
+        def op5_bmm(attention_scores, V_heads):
+            weighted_activation = tt_lib.tensor.bert_large_post_softmax_bmm(
+                attention_scores,
+                V_heads,
+                output_mem_config=model_config["OP5_POST_SOFTMAX_BMM_OUTPUT_MEMCFG"],
+                output_dtype=model_config["OP5_POST_SOFTMAX_BMM_OUTPUT_DTYPE"],
+            )
+
+            return weighted_activation
+
     def op6_unmake_attention_heads(x):
         if num_heads == 1:
             return x
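Taken together, ops 1 through 6 in mha.py implement standard multi-head attention. A plain-torch reference of the sequence (reference only: the device ops additionally manage memory configs, sharding, and in-place softmax, and this assumes the usual 1/sqrt(head_dim) scaling behind freciprocal_of_sqrt_hidden_dim):

```python
import math

import torch

# Plain-torch reference for the op1..op6 sequence in mha.py.
def mha_reference(activation, qkv_weight, qkv_bias, attention_mask, num_heads):
    B, S, H = activation.shape
    head_dim = H // num_heads

    def split_heads(t):
        return t.view(B, S, num_heads, head_dim).transpose(1, 2)

    qkv = activation @ qkv_weight + qkv_bias          # op1: fused QKV matmul
    q, k, v = map(split_heads, qkv.chunk(3, dim=-1))  # op2: split into heads
    scores = q @ k.transpose(-2, -1)                  # op3: Q @ K^T
    scores = scores / math.sqrt(head_dim) + attention_mask
    probs = torch.softmax(scores, dim=-1)             # op4: scale + mask + softmax
    context = probs @ v                               # op5: attention @ V
    return context.transpose(1, 2).reshape(B, S, H)   # op6: merge heads back
```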