From 0d96801608bb4ba994dada43ab4a570a1ba5b8f4 Mon Sep 17 00:00:00 2001 From: yugaoTT Date: Fri, 12 Jan 2024 20:24:59 +0000 Subject: [PATCH] #4943: add groupnorm support, supports height/block shard, rm and tile layout --- .../unit_testing/test_bert_ops.py | 733 ++++++++-------- .../unit_testing/test_groupnorm_sharded.py | 208 +++++ .../op_library/groupnorm/groupnorm_op.cpp | 804 +++++++++++++++++- .../op_library/groupnorm/groupnorm_op.hpp | 41 + .../kernels/compute/groupnorm_sharded.cpp | 461 ++++++++++ ...reader_mcast_receiver_unary_sharded_gn.cpp | 41 + .../reader_mcast_sender_unary_sharded_gn.cpp | 195 +++++ .../dataflow/writer_unary_sharded_gn.cpp | 109 +++ .../writer_unary_sharded_gn_rm_gb.cpp | 133 +++ .../tt_lib/csrc/operations/primary/module.hpp | 27 + 10 files changed, 2419 insertions(+), 333 deletions(-) create mode 100644 tests/tt_eager/python_api_testing/unit_testing/test_groupnorm_sharded.py create mode 100644 tt_eager/tt_dnn/op_library/groupnorm/kernels/compute/groupnorm_sharded.cpp create mode 100644 tt_eager/tt_dnn/op_library/groupnorm/kernels/dataflow/reader_mcast_receiver_unary_sharded_gn.cpp create mode 100644 tt_eager/tt_dnn/op_library/groupnorm/kernels/dataflow/reader_mcast_sender_unary_sharded_gn.cpp create mode 100644 tt_eager/tt_dnn/op_library/groupnorm/kernels/dataflow/writer_unary_sharded_gn.cpp create mode 100644 tt_eager/tt_dnn/op_library/groupnorm/kernels/dataflow/writer_unary_sharded_gn_rm_gb.cpp diff --git a/tests/tt_eager/python_api_testing/unit_testing/test_bert_ops.py b/tests/tt_eager/python_api_testing/unit_testing/test_bert_ops.py index f540532f813..141ab8b2737 100644 --- a/tests/tt_eager/python_api_testing/unit_testing/test_bert_ops.py +++ b/tests/tt_eager/python_api_testing/unit_testing/test_bert_ops.py @@ -20,304 +20,422 @@ from models.utility_functions import torch2tt_tensor, tt2torch_tensor, pad_by_zero -@pytest.mark.skipif(is_wormhole_b0(), reason="Unsupported parallelizations for WH B0") -@pytest.mark.parametrize( - "fidelity", [ttl.tensor.MathFidelity.LoFi, ttl.tensor.MathFidelity.HiFi2], ids=["LoFi", "HiFi2"] -) -@pytest.mark.parametrize("has_bias", [True, False], ids=["bias", "no_bias"]) -@pytest.mark.parametrize( - "in1_in_dram, out_sharded, in0_sharded, M, K, N, activation", - [ - # (False, True, True, 12*128, 1024, 1024, None), - # (False, True, True, 12*128, 4096, 1024, None), - # (False, True, True, 12*128, 8192, 1024, None), - # one core - # (False, False, False, 128, 256, 128, None), - # # in1-L1-fusedQKV - (False, True, True, 4608, 1024, 3072, None), # both sharded - (False, True, False, 4608, 1024, 3072, None), # out sharded, in0 interleaved - (False, False, True, 4608, 1024, 3072, None), # out interleaved, in0 sharded - (False, False, False, 4608, 1024, 3072, None), # out interleaved, in0 interleaved - # # # in1-dram-fusedQKV - (True, True, True, 4608, 1024, 3072, None), - (True, True, False, 4608, 1024, 3072, None), - (True, False, True, 4608, 1024, 3072, None), - (True, False, False, 4608, 1024, 3072, None), - # # # in1-L1-selfout - (False, True, True, 4608, 1024, 1024, None), - (False, True, False, 4608, 1024, 1024, None), - (False, False, True, 4608, 1024, 1024, None), - (False, False, False, 4608, 1024, 1024, None), - # # # in1-dram-selfout - (True, True, True, 4608, 1024, 1024, None), - (True, True, False, 4608, 1024, 1024, None), - (True, False, True, 4608, 1024, 1024, None), - (True, False, False, 4608, 1024, 1024, None), - # # # in1-L1-ff1 - (False, True, True, 4608, 1024, 4096, (ttl.tensor.FusibleActivation.GELU, True)), - 
(False, True, False, 4608, 1024, 4096, (ttl.tensor.FusibleActivation.GELU, True)), - (False, False, True, 4608, 1024, 4096, (ttl.tensor.FusibleActivation.GELU, True)), - (False, False, False, 4608, 1024, 4096, (ttl.tensor.FusibleActivation.GELU, True)), - # # # in1-dram-ff1 - (True, True, True, 4608, 1024, 4096, (ttl.tensor.FusibleActivation.GELU, True)), - (True, True, False, 4608, 1024, 4096, (ttl.tensor.FusibleActivation.GELU, True)), - (True, False, True, 4608, 1024, 4096, (ttl.tensor.FusibleActivation.GELU, True)), - (True, False, False, 4608, 1024, 4096, (ttl.tensor.FusibleActivation.GELU, True)), - # # # in1-L1-ff1 - no Gelu - (False, True, True, 4608, 1024, 4096, None), - (False, True, False, 4608, 1024, 4096, None), - (False, False, True, 4608, 1024, 4096, None), - (False, False, False, 4608, 1024, 4096, None), - # # # in1-dram-ff1 - no Gelu - (True, True, True, 4608, 1024, 4096, None), - (True, True, False, 4608, 1024, 4096, None), - (True, False, True, 4608, 1024, 4096, None), - (True, False, False, 4608, 1024, 4096, None), - # # # in1-L1-ff2 - (False, True, True, 4608, 4096, 1024, None), - (False, True, False, 4608, 4096, 1024, None), - (False, False, True, 4608, 4096, 1024, None), - (False, False, False, 4608, 4096, 1024, None), - # # # in1-dram-ff2 - (True, True, True, 4608, 4096, 1024, None), - (True, True, False, 4608, 4096, 1024, None), - (True, False, True, 4608, 4096, 1024, None), - (True, False, False, 4608, 4096, 1024, None), - ], -) -def test_bert_linear( - device, fidelity, in0_sharded, out_sharded, in1_in_dram, has_bias, M, K, N, activation, function_level_defaults -): - in0_shape = [1, 1, M, K] - in1_shape = [1, 1, K, N] - bias_shape = [1, 1, N] - grid_size = (12, 8) - # grid_size = (2, 2) - shard_shape = [M // grid_size[0], K // grid_size[1]] # shard height, width - - in0_block_w = K // grid_size[1] // 32 # 16 - in0_block_h = M // grid_size[0] // 32 - out_block_h = M // grid_size[0] // 32 - out_block_w = N // grid_size[1] // 32 - - if out_block_w <= 8: - out_subblock_w = out_block_w - out_subblock_h = 8 // out_subblock_w - else: - out_subblock_h = 1 - out_subblock_w = 8 // out_subblock_h - while out_block_w % out_subblock_w != 0: - out_subblock_w = out_block_w // 2 - - # in0_block_w = K // grid_size[1] // 32 - # out_subblock_w = 4 - # out_subblock_h = 4 +# @pytest.mark.skipif(is_wormhole_b0(), reason="Unsupported parallelizations for WH B0") +# @pytest.mark.parametrize( +# "fidelity", [ttl.tensor.MathFidelity.LoFi, ttl.tensor.MathFidelity.HiFi2], ids=["LoFi", "HiFi2"] +# ) +# @pytest.mark.parametrize("has_bias", [True, False], ids=["bias", "no_bias"]) +# @pytest.mark.parametrize( +# "in1_in_dram, out_sharded, in0_sharded, M, K, N, activation", +# [ +# # (False, True, True, 12*128, 1024, 1024, None), +# # (False, True, True, 12*128, 4096, 1024, None), +# # (False, True, True, 12*128, 8192, 1024, None), +# # one core +# # (False, False, False, 128, 256, 128, None), +# # # in1-L1-fusedQKV +# (False, True, True, 4608, 1024, 3072, None), # both sharded +# (False, True, False, 4608, 1024, 3072, None), # out sharded, in0 interleaved +# (False, False, True, 4608, 1024, 3072, None), # out interleaved, in0 sharded +# (False, False, False, 4608, 1024, 3072, None), # out interleaved, in0 interleaved +# # # # in1-dram-fusedQKV +# (True, True, True, 4608, 1024, 3072, None), +# (True, True, False, 4608, 1024, 3072, None), +# (True, False, True, 4608, 1024, 3072, None), +# (True, False, False, 4608, 1024, 3072, None), +# # # # in1-L1-selfout +# (False, True, True, 4608, 1024, 
1024, None), +# (False, True, False, 4608, 1024, 1024, None), +# (False, False, True, 4608, 1024, 1024, None), +# (False, False, False, 4608, 1024, 1024, None), +# # # # in1-dram-selfout +# (True, True, True, 4608, 1024, 1024, None), +# (True, True, False, 4608, 1024, 1024, None), +# (True, False, True, 4608, 1024, 1024, None), +# (True, False, False, 4608, 1024, 1024, None), +# # # # in1-L1-ff1 +# (False, True, True, 4608, 1024, 4096, (ttl.tensor.FusibleActivation.GELU, True)), +# (False, True, False, 4608, 1024, 4096, (ttl.tensor.FusibleActivation.GELU, True)), +# (False, False, True, 4608, 1024, 4096, (ttl.tensor.FusibleActivation.GELU, True)), +# (False, False, False, 4608, 1024, 4096, (ttl.tensor.FusibleActivation.GELU, True)), +# # # # in1-dram-ff1 +# (True, True, True, 4608, 1024, 4096, (ttl.tensor.FusibleActivation.GELU, True)), +# (True, True, False, 4608, 1024, 4096, (ttl.tensor.FusibleActivation.GELU, True)), +# (True, False, True, 4608, 1024, 4096, (ttl.tensor.FusibleActivation.GELU, True)), +# (True, False, False, 4608, 1024, 4096, (ttl.tensor.FusibleActivation.GELU, True)), +# # # # in1-L1-ff1 - no Gelu +# (False, True, True, 4608, 1024, 4096, None), +# (False, True, False, 4608, 1024, 4096, None), +# (False, False, True, 4608, 1024, 4096, None), +# (False, False, False, 4608, 1024, 4096, None), +# # # # in1-dram-ff1 - no Gelu +# (True, True, True, 4608, 1024, 4096, None), +# (True, True, False, 4608, 1024, 4096, None), +# (True, False, True, 4608, 1024, 4096, None), +# (True, False, False, 4608, 1024, 4096, None), +# # # # in1-L1-ff2 +# (False, True, True, 4608, 4096, 1024, None), +# (False, True, False, 4608, 4096, 1024, None), +# (False, False, True, 4608, 4096, 1024, None), +# (False, False, False, 4608, 4096, 1024, None), +# # # # in1-dram-ff2 +# (True, True, True, 4608, 4096, 1024, None), +# (True, True, False, 4608, 4096, 1024, None), +# (True, False, True, 4608, 4096, 1024, None), +# (True, False, False, 4608, 4096, 1024, None), +# ], +# ) +# def test_bert_linear( +# device, fidelity, in0_sharded, out_sharded, in1_in_dram, has_bias, M, K, N, activation, function_level_defaults +# ): +# in0_shape = [1, 1, M, K] +# in1_shape = [1, 1, K, N] +# bias_shape = [1, 1, N] +# grid_size = (12, 8) +# # grid_size = (2, 2) +# shard_shape = [M // grid_size[0], K // grid_size[1]] # shard height, width + +# in0_block_w = K // grid_size[1] // 32 # 16 +# in0_block_h = M // grid_size[0] // 32 +# out_block_h = M // grid_size[0] // 32 +# out_block_w = N // grid_size[1] // 32 + +# if out_block_w <= 8: +# out_subblock_w = out_block_w +# out_subblock_h = 8 // out_subblock_w +# else: +# out_subblock_h = 1 +# out_subblock_w = 8 // out_subblock_h +# while out_block_w % out_subblock_w != 0: +# out_subblock_w = out_block_w // 2 + +# # in0_block_w = K // grid_size[1] // 32 +# # out_subblock_w = 4 +# # out_subblock_h = 4 + +# logger.debug("in0 block w h " + str(in0_block_w * 32) + " " + str(in0_block_h * 32)) +# logger.debug("in1 block w h " + str(out_block_w * 32) + " " + str(in0_block_w * 32)) +# logger.debug("out block w h " + str(out_block_w * 32) + " " + str(out_block_h * 32)) +# logger.debug("out subblock w h " + str(out_subblock_w * 32) + " " + str(out_subblock_h * 32)) + +# interleaved_mem_config_L1 = ttl.tensor.MemoryConfig( +# memory_layout=ttl.tensor.TensorMemoryLayout.INTERLEAVED, +# buffer_type=ttl.tensor.BufferType.L1, +# ) +# interleaved_mem_config_DRAM = ttl.tensor.MemoryConfig( +# memory_layout=ttl.tensor.TensorMemoryLayout.INTERLEAVED, +# buffer_type=ttl.tensor.BufferType.DRAM, +# 
) +# sharded_mem_config = ttl.tensor.MemoryConfig( +# memory_layout=ttl.tensor.TensorMemoryLayout.BLOCK_SHARDED, +# buffer_type=ttl.tensor.BufferType.L1, +# ) + +# in0 = torch.randn(in0_shape).bfloat16().float() +# in1 = torch.randn(in1_shape).bfloat16().float() +# bias = torch.randn(bias_shape).bfloat16().float() + +# if in0_sharded: +# in0_t = torch2tt_tensor( +# in0, device, tt_memory_config=interleaved_mem_config_DRAM, tt_dtype=ttl.tensor.DataType.BFLOAT8_B +# ) +# else: +# in0_t = torch2tt_tensor( +# in0, device, tt_memory_config=interleaved_mem_config_L1, tt_dtype=ttl.tensor.DataType.BFLOAT8_B +# ) + +# if in1_in_dram: +# in1_t = torch2tt_tensor( +# in1, device, tt_memory_config=interleaved_mem_config_DRAM, tt_dtype=ttl.tensor.DataType.BFLOAT8_B +# ) +# else: +# in1_t = torch2tt_tensor( +# in1, device, tt_memory_config=interleaved_mem_config_L1, tt_dtype=ttl.tensor.DataType.BFLOAT8_B +# ) + +# output_mem_config = sharded_mem_config if out_sharded else interleaved_mem_config_L1 +# bias_t = pad_by_zero( +# bias, device, tt_memory_config=interleaved_mem_config_L1, tt_dtype=ttl.tensor.DataType.BFLOAT8_B +# )[0] + +# if in0_sharded: +# in0_t = ttl.tensor.interleaved_to_sharded( +# in0_t, +# grid_size, +# [M // grid_size[0], K // grid_size[1]], +# ttl.tensor.TensorMemoryLayout.BLOCK_SHARDED, +# ttl.tensor.ShardOrientation.COL_MAJOR, +# ) + +# program_config = ttl.operations.primary.MatmulMultiCoreReuseMultiCastProgramConfig( +# compute_with_storage_grid_size=grid_size, +# in0_block_w=in0_block_w, +# out_subblock_h=out_subblock_h, +# out_subblock_w=out_subblock_w, +# per_core_M=out_block_h, +# per_core_N=out_block_w, +# transpose_mcast=True, +# # transpose_mcast=False, +# fused_activation=activation, +# ) + +# if has_bias: +# output_t = ttl.operations.primary.matmul( +# in0_t, +# in1_t, +# bias=bias_t, +# program_config=program_config, +# output_mem_config=output_mem_config, +# math_fidelity=fidelity, +# ) +# else: +# output_t = ttl.operations.primary.matmul( +# in0_t, +# in1_t, +# program_config=program_config, +# output_mem_config=output_mem_config, +# math_fidelity=fidelity, +# ) + +# if out_sharded: +# output_t = ttl.tensor.sharded_to_interleaved(output_t, interleaved_mem_config_L1) + +# pt_out = in0 @ in1 + +# if has_bias: +# pt_out = pt_out + bias + +# if activation != None: +# pt_out = torch.nn.functional.gelu(pt_out) +# tt_out = tt2torch_tensor(output_t) + +# passing, output = comp_pcc(pt_out, tt_out) +# logger.info(output) +# assert passing + + +# @pytest.mark.skipif(is_grayskull(), reason="not tested for GS") +# @pytest.mark.parametrize("packer_l1_acc", [True, False], ids=["pack_l1", "no_pack_l1"]) +# @pytest.mark.parametrize("fp32_acc_mode", [True, False], ids=["fp32", "no_fp32"]) +# @pytest.mark.parametrize( +# "fidelity", +# [ +# ttl.tensor.MathFidelity.LoFi, +# ], +# ids=["LoFi"], +# ) +# @pytest.mark.parametrize("has_bias", [True, False], ids=["bias", "no_bias"]) +# @pytest.mark.parametrize( +# "in1_in_dram, out_sharded, in0_sharded, M, K, N, activation", +# [ +# # # in1-L1-fusedQKV +# (False, True, True, 2688, 1024, 3072, None), # both sharded +# (False, True, False, 2688, 1024, 3072, None), # out sharded, in0 interleaved +# (False, False, True, 2688, 1024, 3072, None), # out interleaved, in0 sharded +# (False, False, False, 2688, 1024, 3072, None), # out interleaved, in0 interleaved +# # # # # in1-dram-fusedQKV +# (True, True, True, 2688, 1024, 3072, None), +# (True, True, False, 2688, 1024, 3072, None), +# (True, False, True, 2688, 1024, 3072, None), +# (True, False, 
False, 2688, 1024, 3072, None), +# # # # # in1-L1-selfout +# (False, True, True, 2688, 1024, 1024, None), +# (False, True, False, 2688, 1024, 1024, None), +# (False, False, True, 2688, 1024, 1024, None), +# (False, False, False, 2688, 1024, 1024, None), +# # # # # in1-dram-selfout +# (True, True, True, 2688, 1024, 1024, None), +# (True, True, False, 2688, 1024, 1024, None), +# (True, False, True, 2688, 1024, 1024, None), +# (True, False, False, 2688, 1024, 1024, None), +# # # # # in1-L1-ff1 +# (False, True, True, 2688, 1024, 4096, (ttl.tensor.FusibleActivation.GELU, True)), +# (False, True, False, 2688, 1024, 4096, (ttl.tensor.FusibleActivation.GELU, True)), +# (False, False, True, 2688, 1024, 4096, (ttl.tensor.FusibleActivation.GELU, True)), +# (False, False, False, 2688, 1024, 4096, (ttl.tensor.FusibleActivation.GELU, True)), +# # # # # in1-dram-ff1 +# (True, True, True, 2688, 1024, 4096, (ttl.tensor.FusibleActivation.GELU, True)), +# (True, True, False, 2688, 1024, 4096, (ttl.tensor.FusibleActivation.GELU, True)), +# (True, False, True, 2688, 1024, 4096, (ttl.tensor.FusibleActivation.GELU, True)), +# (True, False, False, 2688, 1024, 4096, (ttl.tensor.FusibleActivation.GELU, True)), +# # # # # # in1-L1-ff1 - no Gelu +# (False, True, True, 2688, 1024, 4096, None), +# (False, True, False, 2688, 1024, 4096, None), +# (False, False, True, 2688, 1024, 4096, None), +# (False, False, False, 2688, 1024, 4096, None), +# # # # # in1-dram-ff1 - no Gelu +# (True, True, True, 2688, 1024, 4096, None), +# (True, True, False, 2688, 1024, 4096, None), +# (True, False, True, 2688, 1024, 4096, None), +# (True, False, False, 2688, 1024, 4096, None), +# # # # # in1-L1-ff2 +# (False, True, True, 2688, 4096, 1024, None), +# (False, True, False, 2688, 4096, 1024, None), +# (False, False, True, 2688, 4096, 1024, None), +# (False, False, False, 2688, 4096, 1024, None), +# # # # # in1-dram-ff2 +# (True, True, True, 2688, 4096, 1024, None), +# (True, True, False, 2688, 4096, 1024, None), +# (True, False, True, 2688, 4096, 1024, None), +# (True, False, False, 2688, 4096, 1024, None), +# ], +# ) +# @skip_for_wormhole_b0("WH ND hang, see issue #4392") +# def test_bert_linear_batch7( +# device, +# fidelity, +# in0_sharded, +# out_sharded, +# in1_in_dram, +# has_bias, +# fp32_acc_mode, +# packer_l1_acc, +# M, +# K, +# N, +# activation, +# function_level_defaults, +# ): +# in0_shape = [1, 1, M, K] +# in1_shape = [1, 1, K, N] +# bias_shape = [1, 1, N] +# grid_size = (8, 7) + +# in0_block_h = M // grid_size[1] // 32 +# in0_block_w = K // grid_size[0] // 32 +# out_block_h = M // grid_size[1] // 32 +# out_block_w = N // grid_size[0] // 32 + +# if fp32_acc_mode == True: +# out_subblock_w = 4 +# out_subblock_h = 1 +# else: +# if out_block_w <= 8: +# out_subblock_w = out_block_w +# out_subblock_h = 8 // out_subblock_w +# else: +# out_subblock_h = 1 +# out_subblock_w = 8 // out_subblock_h +# while out_block_w % out_subblock_w != 0: +# out_subblock_w = out_block_w // 2 + +# logger.debug("in0 block w h " + str(in0_block_w * 32) + " " + str(in0_block_h * 32)) +# logger.debug("in1 block w h " + str(out_block_w * 32) + " " + str(in0_block_w * 32)) +# logger.debug("out block w h " + str(out_block_w * 32) + " " + str(out_block_h * 32)) +# logger.debug("out subblock w h " + str(out_subblock_w * 32) + " " + str(out_subblock_h * 32)) + +# interleaved_mem_config_L1 = ttl.tensor.MemoryConfig( +# memory_layout=ttl.tensor.TensorMemoryLayout.INTERLEAVED, +# buffer_type=ttl.tensor.BufferType.L1, +# ) +# interleaved_mem_config_DRAM = 
ttl.tensor.MemoryConfig( +# memory_layout=ttl.tensor.TensorMemoryLayout.INTERLEAVED, +# buffer_type=ttl.tensor.BufferType.DRAM, +# ) +# sharded_mem_config = ttl.tensor.MemoryConfig( +# memory_layout=ttl.tensor.TensorMemoryLayout.BLOCK_SHARDED, +# buffer_type=ttl.tensor.BufferType.L1, +# ) + +# in0 = torch.randn(in0_shape).bfloat16().float() +# in1 = torch.randn(in1_shape).bfloat16().float() +# bias = torch.randn(bias_shape).bfloat16().float() + +# in0_t = torch2tt_tensor( +# in0, device, tt_memory_config=interleaved_mem_config_DRAM, tt_dtype=ttl.tensor.DataType.BFLOAT8_B +# ) +# in1_t = torch2tt_tensor( +# in1, device, tt_memory_config=interleaved_mem_config_DRAM, tt_dtype=ttl.tensor.DataType.BFLOAT8_B +# ) + +# output_mem_config = sharded_mem_config if out_sharded else interleaved_mem_config_L1 +# bias_t = pad_by_zero( +# bias, device, tt_memory_config=interleaved_mem_config_L1, tt_dtype=ttl.tensor.DataType.BFLOAT8_B +# )[0] + +# if in0_sharded: +# in0_t = ttl.tensor.interleaved_to_sharded( +# in0_t, +# grid_size, +# [M // grid_size[1], K // grid_size[0]], +# ttl.tensor.TensorMemoryLayout.BLOCK_SHARDED, +# ttl.tensor.ShardOrientation.ROW_MAJOR, +# ) + +# program_config = ttl.operations.primary.MatmulMultiCoreReuseMultiCastProgramConfig( +# compute_with_storage_grid_size=grid_size, +# in0_block_w=in0_block_w, +# out_subblock_h=out_subblock_h, +# out_subblock_w=out_subblock_w, +# per_core_M=out_block_h, +# per_core_N=out_block_w, +# transpose_mcast=False, +# fused_activation=activation, +# ) + +# if has_bias: +# output_t = ttl.operations.primary.matmul( +# in0_t, +# in1_t, +# bias=bias_t, +# program_config=program_config, +# output_mem_config=output_mem_config, +# math_fidelity=fidelity, +# fp32_dest_acc_en=fp32_acc_mode, +# packer_l1_acc=packer_l1_acc, +# ) +# else: +# output_t = ttl.operations.primary.matmul( +# in0_t, +# in1_t, +# program_config=program_config, +# output_mem_config=output_mem_config, +# math_fidelity=fidelity, +# fp32_dest_acc_en=fp32_acc_mode, +# packer_l1_acc=packer_l1_acc, +# ) + +# if out_sharded: +# output_t = ttl.tensor.sharded_to_interleaved(output_t, interleaved_mem_config_L1) + +# pt_out = in0 @ in1 + +# if has_bias: +# pt_out = pt_out + bias + +# if activation != None: +# pt_out = torch.nn.functional.gelu(pt_out) +# tt_out = tt2torch_tensor(output_t) + +# passing, output = comp_pcc(pt_out, tt_out) +# logger.info(output) +# assert passing - logger.debug("in0 block w h " + str(in0_block_w * 32) + " " + str(in0_block_h * 32)) - logger.debug("in1 block w h " + str(out_block_w * 32) + " " + str(in0_block_w * 32)) - logger.debug("out block w h " + str(out_block_w * 32) + " " + str(out_block_h * 32)) - logger.debug("out subblock w h " + str(out_subblock_w * 32) + " " + str(out_subblock_h * 32)) - interleaved_mem_config_L1 = ttl.tensor.MemoryConfig( - memory_layout=ttl.tensor.TensorMemoryLayout.INTERLEAVED, - buffer_type=ttl.tensor.BufferType.L1, - ) - interleaved_mem_config_DRAM = ttl.tensor.MemoryConfig( - memory_layout=ttl.tensor.TensorMemoryLayout.INTERLEAVED, - buffer_type=ttl.tensor.BufferType.DRAM, - ) - sharded_mem_config = ttl.tensor.MemoryConfig( - memory_layout=ttl.tensor.TensorMemoryLayout.BLOCK_SHARDED, - buffer_type=ttl.tensor.BufferType.L1, - ) - - in0 = torch.randn(in0_shape).bfloat16().float() - in1 = torch.randn(in1_shape).bfloat16().float() - bias = torch.randn(bias_shape).bfloat16().float() - - if in0_sharded: - in0_t = torch2tt_tensor( - in0, device, tt_memory_config=interleaved_mem_config_DRAM, tt_dtype=ttl.tensor.DataType.BFLOAT8_B - ) - else: - 
in0_t = torch2tt_tensor( - in0, device, tt_memory_config=interleaved_mem_config_L1, tt_dtype=ttl.tensor.DataType.BFLOAT8_B - ) - - if in1_in_dram: - in1_t = torch2tt_tensor( - in1, device, tt_memory_config=interleaved_mem_config_DRAM, tt_dtype=ttl.tensor.DataType.BFLOAT8_B - ) - else: - in1_t = torch2tt_tensor( - in1, device, tt_memory_config=interleaved_mem_config_L1, tt_dtype=ttl.tensor.DataType.BFLOAT8_B - ) - - output_mem_config = sharded_mem_config if out_sharded else interleaved_mem_config_L1 - bias_t = pad_by_zero( - bias, device, tt_memory_config=interleaved_mem_config_L1, tt_dtype=ttl.tensor.DataType.BFLOAT8_B - )[0] - - if in0_sharded: - in0_t = ttl.tensor.interleaved_to_sharded( - in0_t, - grid_size, - [M // grid_size[0], K // grid_size[1]], - ttl.tensor.TensorMemoryLayout.BLOCK_SHARDED, - ttl.tensor.ShardOrientation.COL_MAJOR, - ) - - program_config = ttl.operations.primary.MatmulMultiCoreReuseMultiCastProgramConfig( - compute_with_storage_grid_size=grid_size, - in0_block_w=in0_block_w, - out_subblock_h=out_subblock_h, - out_subblock_w=out_subblock_w, - per_core_M=out_block_h, - per_core_N=out_block_w, - transpose_mcast=True, - # transpose_mcast=False, - fused_activation=activation, - ) - - if has_bias: - output_t = ttl.operations.primary.matmul( - in0_t, - in1_t, - bias=bias_t, - program_config=program_config, - output_mem_config=output_mem_config, - math_fidelity=fidelity, - ) - else: - output_t = ttl.operations.primary.matmul( - in0_t, - in1_t, - program_config=program_config, - output_mem_config=output_mem_config, - math_fidelity=fidelity, - ) - - if out_sharded: - output_t = ttl.tensor.sharded_to_interleaved(output_t, interleaved_mem_config_L1) - - pt_out = in0 @ in1 - - if has_bias: - pt_out = pt_out + bias - - if activation != None: - pt_out = torch.nn.functional.gelu(pt_out) - tt_out = tt2torch_tensor(output_t) - - passing, output = comp_pcc(pt_out, tt_out) - logger.info(output) - assert passing - - -@pytest.mark.skipif(is_grayskull(), reason="not tested for GS") -@pytest.mark.parametrize("packer_l1_acc", [True, False], ids=["pack_l1", "no_pack_l1"]) -@pytest.mark.parametrize("fp32_acc_mode", [True, False], ids=["fp32", "no_fp32"]) -@pytest.mark.parametrize( - "fidelity", - [ - ttl.tensor.MathFidelity.LoFi, - ], - ids=["LoFi"], -) -@pytest.mark.parametrize("has_bias", [True, False], ids=["bias", "no_bias"]) @pytest.mark.parametrize( - "in1_in_dram, out_sharded, in0_sharded, M, K, N, activation", + "M, K, N", [ - # # in1-L1-fusedQKV - (False, True, True, 2688, 1024, 3072, None), # both sharded - (False, True, False, 2688, 1024, 3072, None), # out sharded, in0 interleaved - (False, False, True, 2688, 1024, 3072, None), # out interleaved, in0 sharded - (False, False, False, 2688, 1024, 3072, None), # out interleaved, in0 interleaved - # # # # in1-dram-fusedQKV - (True, True, True, 2688, 1024, 3072, None), - (True, True, False, 2688, 1024, 3072, None), - (True, False, True, 2688, 1024, 3072, None), - (True, False, False, 2688, 1024, 3072, None), - # # # # in1-L1-selfout - (False, True, True, 2688, 1024, 1024, None), - (False, True, False, 2688, 1024, 1024, None), - (False, False, True, 2688, 1024, 1024, None), - (False, False, False, 2688, 1024, 1024, None), - # # # # in1-dram-selfout - (True, True, True, 2688, 1024, 1024, None), - (True, True, False, 2688, 1024, 1024, None), - (True, False, True, 2688, 1024, 1024, None), - (True, False, False, 2688, 1024, 1024, None), - # # # # in1-L1-ff1 - (False, True, True, 2688, 1024, 4096, (ttl.tensor.FusibleActivation.GELU, True)), - 
(False, True, False, 2688, 1024, 4096, (ttl.tensor.FusibleActivation.GELU, True)), - (False, False, True, 2688, 1024, 4096, (ttl.tensor.FusibleActivation.GELU, True)), - (False, False, False, 2688, 1024, 4096, (ttl.tensor.FusibleActivation.GELU, True)), - # # # # in1-dram-ff1 - (True, True, True, 2688, 1024, 4096, (ttl.tensor.FusibleActivation.GELU, True)), - (True, True, False, 2688, 1024, 4096, (ttl.tensor.FusibleActivation.GELU, True)), - (True, False, True, 2688, 1024, 4096, (ttl.tensor.FusibleActivation.GELU, True)), - (True, False, False, 2688, 1024, 4096, (ttl.tensor.FusibleActivation.GELU, True)), - # # # # # in1-L1-ff1 - no Gelu - (False, True, True, 2688, 1024, 4096, None), - (False, True, False, 2688, 1024, 4096, None), - (False, False, True, 2688, 1024, 4096, None), - (False, False, False, 2688, 1024, 4096, None), - # # # # in1-dram-ff1 - no Gelu - (True, True, True, 2688, 1024, 4096, None), - (True, True, False, 2688, 1024, 4096, None), - (True, False, True, 2688, 1024, 4096, None), - (True, False, False, 2688, 1024, 4096, None), - # # # # in1-L1-ff2 - (False, True, True, 2688, 4096, 1024, None), - (False, True, False, 2688, 4096, 1024, None), - (False, False, True, 2688, 4096, 1024, None), - (False, False, False, 2688, 4096, 1024, None), - # # # # in1-dram-ff2 - (True, True, True, 2688, 4096, 1024, None), - (True, True, False, 2688, 4096, 1024, None), - (True, False, True, 2688, 4096, 1024, None), - (True, False, False, 2688, 4096, 1024, None), + (64, 64, 64), ], ) -@skip_for_wormhole_b0("WH ND hang, see issue #4392") def test_bert_linear_batch7( device, - fidelity, - in0_sharded, - out_sharded, - in1_in_dram, - has_bias, - fp32_acc_mode, - packer_l1_acc, M, K, N, - activation, function_level_defaults, ): in0_shape = [1, 1, M, K] in1_shape = [1, 1, K, N] - bias_shape = [1, 1, N] - grid_size = (8, 7) + grid_size = (1, 1) in0_block_h = M // grid_size[1] // 32 in0_block_w = K // grid_size[0] // 32 out_block_h = M // grid_size[1] // 32 out_block_w = N // grid_size[0] // 32 - if fp32_acc_mode == True: - out_subblock_w = 4 - out_subblock_h = 1 - else: - if out_block_w <= 8: - out_subblock_w = out_block_w - out_subblock_h = 8 // out_subblock_w - else: - out_subblock_h = 1 - out_subblock_w = 8 // out_subblock_h - while out_block_w % out_subblock_w != 0: - out_subblock_w = out_block_w // 2 + out_subblock_w = 1 + out_subblock_h = 1 logger.debug("in0 block w h " + str(in0_block_w * 32) + " " + str(in0_block_h * 32)) logger.debug("in1 block w h " + str(out_block_w * 32) + " " + str(in0_block_w * 32)) @@ -339,7 +457,6 @@ def test_bert_linear_batch7( in0 = torch.randn(in0_shape).bfloat16().float() in1 = torch.randn(in1_shape).bfloat16().float() - bias = torch.randn(bias_shape).bfloat16().float() in0_t = torch2tt_tensor( in0, device, tt_memory_config=interleaved_mem_config_DRAM, tt_dtype=ttl.tensor.DataType.BFLOAT8_B @@ -348,63 +465,15 @@ def test_bert_linear_batch7( in1, device, tt_memory_config=interleaved_mem_config_DRAM, tt_dtype=ttl.tensor.DataType.BFLOAT8_B ) - output_mem_config = sharded_mem_config if out_sharded else interleaved_mem_config_L1 - bias_t = pad_by_zero( - bias, device, tt_memory_config=interleaved_mem_config_L1, tt_dtype=ttl.tensor.DataType.BFLOAT8_B - )[0] - - if in0_sharded: - in0_t = ttl.tensor.interleaved_to_sharded( - in0_t, - grid_size, - [M // grid_size[1], K // grid_size[0]], - ttl.tensor.TensorMemoryLayout.BLOCK_SHARDED, - ttl.tensor.ShardOrientation.ROW_MAJOR, - ) - - program_config = ttl.operations.primary.MatmulMultiCoreReuseMultiCastProgramConfig( - 
compute_with_storage_grid_size=grid_size, - in0_block_w=in0_block_w, - out_subblock_h=out_subblock_h, - out_subblock_w=out_subblock_w, - per_core_M=out_block_h, - per_core_N=out_block_w, - transpose_mcast=False, - fused_activation=activation, - ) + output_mem_config = interleaved_mem_config_L1 - if has_bias: - output_t = ttl.operations.primary.matmul( - in0_t, - in1_t, - bias=bias_t, - program_config=program_config, - output_mem_config=output_mem_config, - math_fidelity=fidelity, - fp32_dest_acc_en=fp32_acc_mode, - packer_l1_acc=packer_l1_acc, - ) - else: - output_t = ttl.operations.primary.matmul( - in0_t, - in1_t, - program_config=program_config, - output_mem_config=output_mem_config, - math_fidelity=fidelity, - fp32_dest_acc_en=fp32_acc_mode, - packer_l1_acc=packer_l1_acc, - ) - - if out_sharded: - output_t = ttl.tensor.sharded_to_interleaved(output_t, interleaved_mem_config_L1) + output_t = ttl.tensor.matmul( + in0_t, + in1_t, + output_mem_config=output_mem_config, + ) pt_out = in0 @ in1 - - if has_bias: - pt_out = pt_out + bias - - if activation != None: - pt_out = torch.nn.functional.gelu(pt_out) tt_out = tt2torch_tensor(output_t) passing, output = comp_pcc(pt_out, tt_out) diff --git a/tests/tt_eager/python_api_testing/unit_testing/test_groupnorm_sharded.py b/tests/tt_eager/python_api_testing/unit_testing/test_groupnorm_sharded.py new file mode 100644 index 00000000000..23b3f687fda --- /dev/null +++ b/tests/tt_eager/python_api_testing/unit_testing/test_groupnorm_sharded.py @@ -0,0 +1,208 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +from pathlib import Path +import sys +from loguru import logger +import pytest + +import torch +import tt_lib as ttl +from tt_lib.utils import ( + pad_weight, + tilize_to_list, + untilize, +) +from tests.tt_eager.python_api_testing.sweep_tests.comparison_funcs import ( + comp_equal, + comp_pcc, +) + +from models.utility_functions import torch2tt_tensor, tt2torch_tensor, pad_by_zero + + +def manual_group_norm(input_tensor, num_groups, eps=1e-5): + N, C, H, W = input_tensor.shape + assert C % num_groups == 0, "Number of channels must be divisible by number of groups" + + # Reshape into groups + group_channels = C // num_groups + input_tensor = input_tensor.view(N, num_groups, group_channels, H, W) + + # Calculate mean and variance + mean = input_tensor.mean(dim=(2, 3, 4), keepdim=True) + var = input_tensor.var(dim=(2, 3, 4), keepdim=True) + + # Normalize + input_tensor = (input_tensor - mean) / torch.sqrt(var + eps) + + # Reshape back to original dimensions + input_tensor = input_tensor.view(N, C, H, W) + + return input_tensor + + +def ref_groupnorm(x, group_size, eps, **kwargs): + n_channels = x.shape[1] + lnorm = torch.nn.GroupNorm(group_size, n_channels, eps, **kwargs) + return lnorm(x) + + +@pytest.mark.parametrize( + "num_batches, C, H, W, num_groups, grid_size, shard_orientation, shard_layout", + [ + (1, 1280, 8, 8, 4, (2, 8), ttl.tensor.ShardOrientation.COL_MAJOR, ttl.tensor.TensorMemoryLayout.BLOCK_SHARDED), + (1, 1280, 8, 8, 8, (8, 2), ttl.tensor.ShardOrientation.ROW_MAJOR, ttl.tensor.TensorMemoryLayout.BLOCK_SHARDED), + (1, 640, 16, 16, 1, (4, 5), ttl.tensor.ShardOrientation.COL_MAJOR, ttl.tensor.TensorMemoryLayout.BLOCK_SHARDED), + (1, 640, 16, 16, 5, (5, 4), ttl.tensor.ShardOrientation.ROW_MAJOR, ttl.tensor.TensorMemoryLayout.BLOCK_SHARDED), + ( + 1, + 320, + 32, + 32, + 1, + (1, 8), + ttl.tensor.ShardOrientation.COL_MAJOR, + ttl.tensor.TensorMemoryLayout.HEIGHT_SHARDED, + ), + ( + 1, + 320, + 
32, + 32, + 5, + (2, 8), + ttl.tensor.ShardOrientation.COL_MAJOR, + ttl.tensor.TensorMemoryLayout.HEIGHT_SHARDED, + ), + ( + 1, + 320, + 32, + 32, + 5, + (4, 8), + ttl.tensor.ShardOrientation.COL_MAJOR, + ttl.tensor.TensorMemoryLayout.HEIGHT_SHARDED, + ), + ], +) +@pytest.mark.parametrize( + "layout", + [(ttl.tensor.Layout.TILE), (ttl.tensor.Layout.ROW_MAJOR)], + ids=["tile", "rm"], +) +@pytest.mark.parametrize( + "test_id", + (0, 1, 2), + ids=["GN", "GN_G", "GN_GB"], +) +def test_groupnorm_sharded( + test_id, device, layout, num_batches, C, H, W, num_groups, grid_size, shard_orientation, shard_layout +): + torch.manual_seed(1234) + + out_mem_config = ttl.tensor.MemoryConfig(ttl.tensor.TensorMemoryLayout.BLOCK_SHARDED, ttl.tensor.BufferType.L1) + gamma_beta_mem_config = ttl.tensor.MemoryConfig(ttl.tensor.TensorMemoryLayout.INTERLEAVED, ttl.tensor.BufferType.L1) + in0_mem_config = ttl.tensor.MemoryConfig(ttl.tensor.TensorMemoryLayout.INTERLEAVED, ttl.tensor.BufferType.DRAM) + + epsf = 1e-2 + + in0_shape = (1, 1, num_batches * W * H, C) + pyt_in0_shape = (num_batches, C, H, W) + pyt_in0 = torch.rand(pyt_in0_shape) * 2 - 0.95 + + if shard_layout == ttl.tensor.TensorMemoryLayout.BLOCK_SHARDED: + if shard_orientation == ttl.tensor.ShardOrientation.COL_MAJOR: + shard_shape = [int(num_batches * W * H / grid_size[0]), int(C / grid_size[1])] + else: + shard_shape = [int(num_batches * W * H / grid_size[1]), int(C / grid_size[0])] + elif shard_layout == ttl.tensor.TensorMemoryLayout.HEIGHT_SHARDED: + shard_shape = [int(num_batches * W * H / (grid_size[1] * grid_size[0])), int(C)] + + # logger.info("shard_shape") + logger.info("shard_shape " + str(shard_shape)) + + in0 = pyt_in0.transpose(1, -1).contiguous().view(1, 1, -1, C) + + # pyt_gamma = torch.ones(in0_shape[3]) + pyt_gamma = torch.rand(in0_shape[3]) * 2 - 1 + # pyt_beta = torch.zeros(in0_shape[3]) + pyt_beta = torch.rand(in0_shape[3]) * 2.0 - 1.1 + gamma = pyt_gamma.reshape(1, 1, -1, 32) + gamma_t = ttl.tensor.Tensor( + gamma.reshape(-1).tolist(), + gamma.shape, + ttl.tensor.DataType.BFLOAT16, + ttl.tensor.Layout.ROW_MAJOR, + ).to(device, gamma_beta_mem_config) + + beta = pyt_beta.reshape(1, 1, -1, 32) + beta_t = ttl.tensor.Tensor( + beta.reshape(-1).tolist(), + beta.shape, + ttl.tensor.DataType.BFLOAT16, + ttl.tensor.Layout.ROW_MAJOR, + ).to(device, gamma_beta_mem_config) + + in0_t = torch2tt_tensor( + in0, device, tt_memory_config=in0_mem_config, tt_dtype=ttl.tensor.DataType.BFLOAT16, tt_layout=layout + ) + in0_t_sharded = ttl.tensor.interleaved_to_sharded(in0_t, grid_size, shard_shape, shard_layout, shard_orientation) + + program_config = ttl.operations.primary.GroupNormShardedMultiCoreProgramConfig( + compute_with_storage_grid_size=grid_size, + out_data_format=ttl.tensor.DataType.BFLOAT16, + inplace=False if layout == ttl.tensor.Layout.ROW_MAJOR else True, + ) + + if test_id == 0: + logger.info("Running GN") + out_t = ttl.operations.primary.groupnorm( + in0_t_sharded, + num_groups, + num_batches, + epsf, + output_mem_config=out_mem_config, + program_config=program_config, + ) + if test_id == 1: + logger.info("Running GN_G") + out_t = ttl.operations.primary.groupnorm( + in0_t_sharded, + num_groups, + num_batches, + epsf, + gamma_t, + output_mem_config=out_mem_config, + program_config=program_config, + ) + if test_id == 2: + logger.info("Running LN_GB") + out_t = ttl.operations.primary.groupnorm( + in0_t_sharded, + num_groups, + num_batches, + epsf, + gamma_t, + beta_t, + output_mem_config=out_mem_config, + program_config=program_config, + ) 
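+    # All three variants write a BLOCK_SHARDED L1 output; it is converted back to an
+    # interleaved tensor below before being pulled to host and compared against
+    # torch.nn.GroupNorm applied to the original NCHW input.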
+ + out_t = ttl.tensor.sharded_to_interleaved(out_t, in0_mem_config) + out = tt2torch_tensor(out_t) + + pyt_groupnorm = torch.nn.GroupNorm(num_groups=num_groups, num_channels=C, eps=epsf) + + if test_id == 1 or test_id == 2: + pyt_groupnorm.weight.data = pyt_gamma + if test_id == 2: + pyt_groupnorm.bias.data = pyt_beta + pyt_out = pyt_groupnorm(pyt_in0) + pyt_out = pyt_out.transpose(1, -1).contiguous().view(1, 1, -1, C) + + passing, output = comp_pcc(pyt_out, out) + logger.info(output) + assert passing diff --git a/tt_eager/tt_dnn/op_library/groupnorm/groupnorm_op.cpp b/tt_eager/tt_dnn/op_library/groupnorm/groupnorm_op.cpp index 679b9959668..3f4b20996b1 100644 --- a/tt_eager/tt_dnn/op_library/groupnorm/groupnorm_op.cpp +++ b/tt_eager/tt_dnn/op_library/groupnorm/groupnorm_op.cpp @@ -6,12 +6,19 @@ #include +#include "tt_dnn/op_library/run_operation.hpp" #include "tt_eager/tt_dnn/op_library/work_split.hpp" #include "tt_eager/tt_dnn/op_library/reshape/reshape_op.hpp" #include "tt_eager/tt_dnn/op_library/composite/composite_ops.hpp" +#include "tt_dnn/op_library/math.hpp" + +#include "tt_metal/host_api.hpp" +#include "tt_metal/common/constants.hpp" +#include "tt_metal/detail/util.hpp" + +#include "third_party/magic_enum/magic_enum.hpp" using namespace tt::constants; -using namespace std; using namespace tt::tt_metal; namespace tt { @@ -56,4 +63,799 @@ Tensor groupnorm( } // namespace tt_metal +namespace operations { +namespace primary { + +inline bool is_dram(const Tensor& input_tensor) { return input_tensor.memory_config().buffer_type == BufferType::DRAM; } +inline bool is_dram(const std::optional input_tensor) { + return input_tensor.has_value() ? is_dram(input_tensor.value()) : true; +} +inline bool is_dram(const Buffer* b) { return b->buffer_type() == BufferType::DRAM; } + +int get_max_subblock(uint32_t n, uint32_t max_subblock_w) { + if (n <= max_subblock_w) { + return n; + } + + for (int quotient = max_subblock_w; quotient > 1; --quotient) { + if (n % quotient == 0) { + return quotient; + } + } + return 1; +} +bool is_rectangle_grid(const std::vector& core_coords) { + if (core_coords.empty()) { + return true; + } + + int min_x = std::numeric_limits::max(); + int max_x = std::numeric_limits::min(); + int min_y = std::numeric_limits::max(); + int max_y = std::numeric_limits::min(); + + for (const auto& coord : core_coords) { + min_x = std::min(min_x, static_cast(coord.x)); + max_x = std::max(max_x, static_cast(coord.x)); + min_y = std::min(min_y, static_cast(coord.y)); + max_y = std::max(max_y, static_cast(coord.y)); + } + + return ((max_x - min_x + 1) * (max_y - min_y + 1)) == core_coords.size(); +} +void split_and_form_rectangle_grids(std::vector& group, std::vector& mcast_group_first, std::vector& mcast_group_mid, std::vector& mcast_group_last) { + + int remove_front = 0; + int remove_back = 0; + int min_total_removal = group.size(); + + for (int front = 0; front <= group.size(); ++front) { + for (int back = 0; front + back <= group.size(); ++back) { + if (is_rectangle_grid(std::vector(group.begin() + front, group.end() - back))) { + int total_removal = front + back; + if (total_removal < min_total_removal) { + min_total_removal = total_removal; + remove_front = front; + remove_back = back; + } + } + } + } + + // Pop and push the front outliers + for (int i = 0; i < remove_front; ++i) { + mcast_group_first.push_back(mcast_group_mid.front()); + mcast_group_mid.erase(mcast_group_mid.begin()); + } + + // Pop and push the back outliers + for (int i = 0; i < remove_back; ++i) { + 
mcast_group_last.push_back(mcast_group_mid.back()); + mcast_group_mid.pop_back(); + } +} +operation::ProgramWithCallbacks groupnorm_sharded_( + const Tensor &a, + const std::optional gamma, + const std::optional beta, + Tensor& output, + float eps, + const uint32_t num_groups, + const uint32_t num_batches, + MathFidelity fidelity, + DataType im_data_format, + CoreCoord grid_size +) { + bool is_row_major_layout = a.layout() == Layout::ROW_MAJOR; + bool is_height_sharding = a.shape()[3] == a.shard_spec().value().shape[1]; + // convert data format + tt::DataFormat in_data_format = tt_metal::datatype_to_dataformat_converter(a.dtype()); + tt::DataFormat out_data_format = tt_metal::datatype_to_dataformat_converter(output.dtype()); + tt::DataFormat cb_data_format = tt_metal::datatype_to_dataformat_converter(im_data_format); + tt::DataFormat gamma_beta_cb_data_format = tt::DataFormat::Float16_b; + // tile sizes + uint32_t in_single_tile_size = tt_metal::detail::TileSize(in_data_format); + uint32_t single_tile_size = tt_metal::detail::TileSize(cb_data_format); + uint32_t out_single_tile_size = tt_metal::detail::TileSize(out_data_format); + uint32_t gamma_beta_single_tile_size = tt_metal::detail::TileSize(gamma_beta_cb_data_format); + // shard shape per core + uint32_t per_core_M = a.shard_spec().value().shape[0]; + uint32_t per_core_N = a.shard_spec().value().shape[1]; + uint32_t per_core_Mt = per_core_M / TILE_HEIGHT; + uint32_t per_core_Nt = per_core_N / TILE_WIDTH; + // tensor shape + const auto shape = a.shape(); + uint32_t H = shape[2]; + uint32_t Ht = H / TILE_HEIGHT; + uint32_t W = shape[3]; + uint32_t Wt = W / TILE_WIDTH; + // grid + uint32_t num_cores_c = grid_size.x; + uint32_t num_cores_r = grid_size.y; + uint32_t num_cores = num_cores_c * num_cores_r; + auto shard_orientation = a.shard_spec().value().orientation; + // split each batch into multiple cores + uint32_t num_shards_r = H / per_core_M; + uint32_t num_cores_per_batch = num_batches > num_shards_r ? 1 : num_shards_r / num_batches; + uint32_t num_shards_c = W / per_core_N; + uint32_t num_cores_per_group = num_groups > num_shards_c ? 1 : num_shards_c / num_groups; + // each core contains multiple batches + uint32_t num_batches_per_core = num_batches > num_shards_r ? num_batches / num_shards_r : 1; + uint32_t num_groups_per_core = num_groups > num_shards_c ? 
num_groups / num_shards_c : 1; + + // subblock + uint32_t block_wt = per_core_Nt / num_groups_per_core; + uint32_t block_ht = per_core_Mt / num_batches_per_core; + uint32_t block_w = block_wt * TILE_WIDTH; + uint32_t block_h = block_ht * TILE_HEIGHT; + uint32_t subblock_wt = get_max_subblock(block_wt, 8); + uint32_t num_subblocks_w = block_wt / subblock_wt; + + log_debug(tt::LogOp, "num_batches: {}", num_batches); + log_debug(tt::LogOp, "num_groups: {}", num_groups); + log_debug(tt::LogOp, "num_cores_r: {}", num_cores_r); + log_debug(tt::LogOp, "num_cores_c: {}", num_cores_c); + log_debug(tt::LogOp, "num_cores_per_batch: {}", num_cores_per_batch); + log_debug(tt::LogOp, "num_cores_per_group: {}", num_cores_per_group); + log_debug(tt::LogOp, "num_batches_per_core: {}", num_batches_per_core); + log_debug(tt::LogOp, "num_groups_per_core: {}", num_groups_per_core); + log_debug(tt::LogOp, "block_wt: {}", block_wt); + log_debug(tt::LogOp, "block_ht: {}", block_ht); + log_debug(tt::LogOp, "block_w: {}", block_w); + log_debug(tt::LogOp, "block_h: {}", block_h); + log_debug(tt::LogOp, "subblock_wt: {}", subblock_wt); + log_debug(tt::LogOp, "num_subblocks_w: {}", num_subblocks_w); + + TT_ASSERT(W % num_groups == 0 && "tensor width must be divisible by num_groups!"); + TT_ASSERT((W / num_groups) % TILE_WIDTH == 0 && "group width must be divisible by tile width!"); + if (shard_orientation == ShardOrientation::ROW_MAJOR and num_groups_per_core == 1) { + TT_ASSERT(num_cores_c % num_groups == 0 && "for RM shard, when each group is split across cores, num_cores_c must be divisible by num_groups!"); + } else if (shard_orientation == ShardOrientation::COL_MAJOR and num_groups_per_core == 1) { + TT_ASSERT(num_cores_r % num_groups == 0 && "for CM shard, when each group is split across cores, num_cores_c must be divisible by num_groups!"); + } + + if (shard_orientation == ShardOrientation::ROW_MAJOR and num_batches_per_core == 1) { + TT_ASSERT(num_cores_r % num_batches == 0 && "for RM shard, when each batch is split across cores, num_cores_r must be divisible by num_batches!"); + } else if (shard_orientation == ShardOrientation::COL_MAJOR and num_groups_per_core == 1) { + TT_ASSERT(num_cores_c % num_batches == 0 && "for CM shard, when each batch is split across cores, num_cores_c must be divisible by num_batches!"); + } + + // get sharded addr + auto in0_addr = a.buffer()->address(); + auto out_addr = output.buffer()->address(); + // gamma, beta addr + auto gamma_dram_addr = gamma.has_value() ? gamma.value().buffer()->address() : 0; + auto beta_dram_addr = beta.has_value() ? beta.value().buffer()->address() : 0; + // num tiles for a, gamma, beta + uint32_t num_tiles = a.volume()/TILE_HW; + uint32_t num_gamma_tiles = gamma.has_value() ? gamma.value().volume()/TILE_HW : 0; + uint32_t num_beta_tiles = beta.has_value() ? 
beta.value().volume()/TILE_HW : 0; + + //////////////////////////////////////////////////////////////////////////// + // Grayskull Device Setup + //////////////////////////////////////////////////////////////////////////// + Device *device = a.device(); + + //////////////////////////////////////////////////////////////////////////// + // Parameters Setup + //////////////////////////////////////////////////////////////////////////// + // block size for in0 (tensor a) + uint32_t in0_block_tiles = per_core_Mt * per_core_Nt; + uint32_t in0_CB_tiles = in0_block_tiles; + uint32_t in0_CB_size = in0_CB_tiles * in_single_tile_size; + // in2 - scaler + uint32_t in2_CB_size = single_tile_size; + // in3 - eps + uint32_t in3_CB_size = single_tile_size; + // gamma + uint32_t in5_CB_size = in0_block_tiles * gamma_beta_single_tile_size / per_core_Mt; + // beta + uint32_t in6_CB_size = in0_block_tiles * gamma_beta_single_tile_size / per_core_Mt; + // itermediate buffers change later + uint32_t x_CB_size = in0_block_tiles * single_tile_size; + uint32_t xmm_CB_size = in0_block_tiles * single_tile_size; + uint32_t ex_partial_CB_size = num_groups_per_core * num_batches_per_core * single_tile_size; + uint32_t ex_CB_size = ex_partial_CB_size; + uint32_t ex_global_CB_size = ex_partial_CB_size; + uint32_t ex_external_CB_size = ex_partial_CB_size; + uint32_t xmm2_CB_size = in0_block_tiles * single_tile_size; + uint32_t ex2pe_CB_size = ex_partial_CB_size; + // output buffer size + uint32_t out_CB_size = in0_block_tiles * out_single_tile_size; + + //////////////////////////////////////////////////////////////////////////// + // Application Setup + //////////////////////////////////////////////////////////////////////////// + Program program = Program(); + // define core ranges + bool use_mcast = num_cores_per_batch > 1 or num_cores_per_group > 1; + uint32_t start_core_x = 0; + uint32_t start_core_y = 0; + + CoreRange all_cores{ + .start={(std::size_t) start_core_x, (std::size_t) start_core_y}, + .end={(std::size_t) start_core_x + num_cores_c - 1, (std::size_t) start_core_y + num_cores_r - 1}}; + // create a vector of cores, in either RM or CM + std::vector core_coords; + for (int i=0; i < num_cores_r * num_cores_c; ++i) { + if (shard_orientation == ShardOrientation::ROW_MAJOR) { + core_coords.push_back(CoreCoord{i % num_cores_c, i / num_cores_c}); + } else { + core_coords.push_back(CoreCoord{i / num_cores_r, i % num_cores_r}); + } + } + std::vector > core_coords2D; + if (shard_orientation == ShardOrientation::ROW_MAJOR) { + for (int i=0; i < num_cores_c / num_cores_per_group; ++i) { + for (int j=0; j < num_cores_r; ++j) { + std::vector temp; + for (int k=0; k < num_cores_per_group; ++k) { + temp.push_back(CoreCoord{(std::size_t)(k + i * num_cores_per_group), (std::size_t)j}); + } + core_coords2D.push_back(temp); + } + } + } else { + for (int i=0; i < num_cores_r / num_cores_per_group; ++i) { + for (int j=0; j < num_cores_c; ++j) { + std::vector temp; + for (int k=0; k < num_cores_per_group; ++k) { + temp.push_back(CoreCoord{(std::size_t)j, (std::size_t)(k + i * num_cores_per_group)}); + } + core_coords2D.push_back(temp); + } + } + } + + // one mcast core per batch per group + std::set mcast_sender_core_ranges; + std::set mcast_receiver_core_ranges; + uint32_t core_index = 0; + uint32_t core_index_offset = 0; + for (int i=0; i < num_batches / num_batches_per_core; ++i) { + uint32_t core_index = core_index_offset; + for (int j=0; j < num_groups / num_groups_per_core; ++j) { + 
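+            // the first core of each (batch, group) partition is designated the mcast sender;
+            // every core not selected here is added to the receiver set below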
mcast_sender_core_ranges.insert(CoreRange{core_coords[core_index]}); + core_index += num_cores_per_group; + core_index_offset += num_cores_per_batch * num_cores_per_group; + } + } + for (int i=0; i < num_cores_r * num_cores_c; ++i) { + // not found in mcast sender + if (mcast_sender_core_ranges.find(CoreRange{core_coords[i]}) == mcast_sender_core_ranges.end()) { + mcast_receiver_core_ranges.insert(CoreRange{core_coords[i]}); + } + } + CoreRangeSet mcast_sender_cores = CoreRangeSet(mcast_sender_core_ranges); + CoreRangeSet mcast_receiver_cores = CoreRangeSet(mcast_receiver_core_ranges); + // mcast groups + std::vector > mcast_groups; + int group_index = -1; + if (is_height_sharding) { + for (int i=0; i < num_cores_r * num_cores_c; ++i) { + if (mcast_sender_core_ranges.find(CoreRange{core_coords[i]}) != mcast_sender_core_ranges.end()) { + group_index += 1; + } + if (group_index >= mcast_groups.size()) { + mcast_groups.push_back(std::vector()); // Add a new group + } + mcast_groups[group_index].push_back(core_coords[i]); + } + } else { + for (int i=0; i < core_coords2D.size(); ++i) { + for (int j=0; j < core_coords2D[i].size(); ++j) { + if (mcast_sender_core_ranges.find(CoreRange{core_coords2D[i][j]}) != mcast_sender_core_ranges.end()) { + group_index += 1; + } + if (group_index >= mcast_groups.size()) { + mcast_groups.push_back(std::vector()); // Add a new group + } + mcast_groups[group_index].push_back(core_coords2D[i][j]); + } + } + } + // how many cores in a mcast group + uint32_t num_cores_per_mcast_group = mcast_groups[0].size(); + // Mcast args + auto reduce_sender_semaphore = tt_metal::CreateSemaphore(program, all_cores, INVALID); + auto reduce_receiver_semaphore = tt_metal::CreateSemaphore(program, all_cores, INVALID); + // reader defines + std::map reader_mcast_sender_defines; + std::map reader_mcast_receiver_defines; + if (gamma.has_value()) { + reader_mcast_sender_defines["FUSE_GAMMA"] = "1"; + reader_mcast_receiver_defines["FUSE_GAMMA"] = "1"; + } + if (beta.has_value()) { + reader_mcast_sender_defines["FUSE_BETA"] = "1"; + reader_mcast_receiver_defines["FUSE_BETA"] = "1"; + } + // reader compile time args + std::vector reader_mcast_sender_compile_time_args = { + (std::uint32_t) reduce_receiver_semaphore, + (std::uint32_t) reduce_sender_semaphore, + (std::uint32_t) num_cores_per_mcast_group, + (std::uint32_t) num_batches_per_core * num_groups_per_core + }; + std::vector reader_mcast_receiver_compile_time_args = { + (std::uint32_t) reduce_receiver_semaphore, + (std::uint32_t) reduce_sender_semaphore, + (std::uint32_t) num_batches_per_core * num_groups_per_core + }; + // reader kernel + auto reader_mcast_sender_kernels_id = CreateKernel( + program, + "tt_eager/tt_dnn/op_library/groupnorm/kernels/dataflow/reader_mcast_sender_unary_sharded_gn.cpp", + mcast_sender_cores, + tt_metal::ReaderDataMovementConfig{.compile_args = reader_mcast_sender_compile_time_args, .defines = reader_mcast_sender_defines} + ); + KernelHandle reader_mcast_receiver_kernels_id = -1; + if (use_mcast) { + reader_mcast_receiver_kernels_id = CreateKernel( + program, + "tt_eager/tt_dnn/op_library/groupnorm/kernels/dataflow/reader_mcast_receiver_unary_sharded_gn.cpp", + mcast_receiver_cores, + tt_metal::ReaderDataMovementConfig{.compile_args = reader_mcast_receiver_compile_time_args, .defines = reader_mcast_receiver_defines} + ); + } + + // writer defines + std::map writer_defines; + // writer compile time args + std::vector writer_mcast_sender_compile_time_args = { + 1, + (std::uint32_t) gamma.has_value(), + 
(std::uint32_t) beta.has_value(), + (std::uint32_t) is_dram(gamma), + (std::uint32_t) is_dram(beta), + (std::uint32_t) per_core_Nt + }; + std::vector writer_mcast_receiver_compile_time_args = { + 0, + (std::uint32_t) gamma.has_value(), + (std::uint32_t) beta.has_value(), + (std::uint32_t) is_dram(gamma), + (std::uint32_t) is_dram(beta), + (std::uint32_t) per_core_Nt + }; + + if (gamma.has_value() and gamma.value().layout() == Layout::ROW_MAJOR) { + auto gamma_stick_size = gamma.value().shape()[3] * gamma.value().element_size(); + bool gamma_stick_size_is_power_of_two = is_power_of_two_at_least_32(gamma_stick_size); + writer_mcast_sender_compile_time_args.push_back((std::uint32_t) gamma_stick_size_is_power_of_two); + writer_mcast_receiver_compile_time_args.push_back((std::uint32_t) gamma_stick_size_is_power_of_two); + if (gamma_stick_size_is_power_of_two) { + uint32_t gamma_log2_stick_size = gamma_stick_size_is_power_of_two ? (std::uint32_t)std::log2(gamma_stick_size) : 0; + writer_mcast_sender_compile_time_args.push_back((std::uint32_t) gamma_log2_stick_size); + writer_mcast_receiver_compile_time_args.push_back((std::uint32_t) gamma_log2_stick_size); + } else { + writer_mcast_sender_compile_time_args.push_back(gamma_stick_size); + writer_mcast_receiver_compile_time_args.push_back(gamma_stick_size); + } + } else if (beta.has_value() and beta.value().layout() == Layout::ROW_MAJOR) { + auto beta_stick_size = beta.value().shape()[3] * beta.value().element_size(); + bool beta_stick_size_is_power_of_two = is_power_of_two_at_least_32(beta_stick_size); + writer_mcast_sender_compile_time_args.push_back((std::uint32_t) beta_stick_size_is_power_of_two); + writer_mcast_receiver_compile_time_args.push_back((std::uint32_t) beta_stick_size_is_power_of_two); + if (beta_stick_size_is_power_of_two) { + uint32_t beta_log2_stick_size = beta_stick_size_is_power_of_two ? (std::uint32_t)std::log2(beta_stick_size) : 0; + writer_mcast_sender_compile_time_args.push_back((std::uint32_t) beta_log2_stick_size); + writer_mcast_receiver_compile_time_args.push_back((std::uint32_t) beta_log2_stick_size); + } else { + writer_mcast_sender_compile_time_args.push_back(beta_stick_size); + writer_mcast_receiver_compile_time_args.push_back(beta_stick_size); + + } + } else { + writer_mcast_sender_compile_time_args.push_back(0); + writer_mcast_sender_compile_time_args.push_back(0); + writer_mcast_receiver_compile_time_args.push_back(0); + writer_mcast_receiver_compile_time_args.push_back(0); + } + + // writer kernel + bool use_row_major_kernel = (gamma.has_value() and gamma.value().layout() == Layout::ROW_MAJOR) or (beta.has_value() and beta.value().layout() == Layout::ROW_MAJOR); + std::string writer_kernel = use_row_major_kernel ? 
"tt_eager/tt_dnn/op_library/groupnorm/kernels/dataflow/writer_unary_sharded_gn_rm_gb.cpp" : "tt_eager/tt_dnn/op_library/groupnorm/kernels/dataflow/writer_unary_sharded_gn.cpp"; + auto writer_mcast_sender_kernels_id = CreateKernel( + program, + writer_kernel, + mcast_sender_cores, + tt_metal::WriterDataMovementConfig{.compile_args = writer_mcast_sender_compile_time_args, .defines = writer_defines} + ); + auto writer_mcast_receiver_kernels_id = CreateKernel( + program, + writer_kernel, + mcast_receiver_cores, + tt_metal::WriterDataMovementConfig{.compile_args = writer_mcast_receiver_compile_time_args, .defines = writer_defines} + ); + // defines + std::map eltwise_binary_defines; + // compute kernel compile time args + std::vector mcast_sender_compute_compile_time_args = { + 1, + gamma.has_value(), + beta.has_value(), + num_cores_per_mcast_group, + num_batches_per_core, + num_groups_per_core, + num_batches_per_core * num_groups_per_core, + block_ht, + block_wt, + block_ht * block_wt, + subblock_wt, + num_subblocks_w, + is_row_major_layout, + per_core_Mt, + per_core_Nt, + per_core_Mt * per_core_Nt, + num_batches_per_core * block_ht * block_wt, + }; + std::vector mcast_receiver_compute_compile_time_args = { + 0, + gamma.has_value(), + beta.has_value(), + num_cores_per_mcast_group, + num_batches_per_core, + num_groups_per_core, + num_batches_per_core * num_groups_per_core, + block_ht, + block_wt, + block_ht * block_wt, + subblock_wt, + num_subblocks_w, + is_row_major_layout, + per_core_Mt, + per_core_Nt, + per_core_Mt * per_core_Nt, + num_batches_per_core * block_ht * block_wt, + }; + // compute kernel + bool fp32_dest_acc_en = false; + bool math_approx_mode = true; + auto mcast_sender_compute_kernels_id = CreateKernel( + program, + "tt_eager/tt_dnn/op_library/groupnorm/kernels/compute/groupnorm_sharded.cpp", + mcast_sender_cores, + tt_metal::ComputeConfig{.math_fidelity = fidelity, .fp32_dest_acc_en = fp32_dest_acc_en, .math_approx_mode = math_approx_mode, .compile_args = mcast_sender_compute_compile_time_args, .defines = eltwise_binary_defines} + ); + auto mcast_receiver_compute_kernels_id = CreateKernel( + program, + "tt_eager/tt_dnn/op_library/groupnorm/kernels/compute/groupnorm_sharded.cpp", + mcast_receiver_cores, + tt_metal::ComputeConfig{.math_fidelity = fidelity, .fp32_dest_acc_en = fp32_dest_acc_en, .math_approx_mode = math_approx_mode, .compile_args = mcast_receiver_compute_compile_time_args, .defines = eltwise_binary_defines} + ); + // Create circular buffers + // in0 sharded + uint32_t in0_cb_index = CB::c_in0; + tt_metal::CircularBufferConfig in0_cb_config = tt_metal::CircularBufferConfig(in0_CB_size, {{in0_cb_index, in_data_format}}) + .set_page_size(in0_cb_index, in_single_tile_size).set_globally_allocated_address(*a.buffer()); + auto cb_in0 = tt_metal::CreateCircularBuffer(program, all_cores, in0_cb_config); + // in2 scaler + uint32_t in2_cb_index = CB::c_in2; + tt_metal::CircularBufferConfig in2_cb_config = tt_metal::CircularBufferConfig(in2_CB_size, {{in2_cb_index, cb_data_format}}) + .set_page_size(in2_cb_index, single_tile_size); + auto cb_in2 = tt_metal::CreateCircularBuffer(program, all_cores, in2_cb_config); + // in4 scaler-c + uint32_t in4_cb_index = CB::c_in4; + tt_metal::CircularBufferConfig in4_cb_config = tt_metal::CircularBufferConfig(in2_CB_size, {{in4_cb_index, cb_data_format}}) + .set_page_size(in4_cb_index, single_tile_size); + auto cb_in4 = tt_metal::CreateCircularBuffer(program, all_cores, in4_cb_config); + // in3 eps + uint32_t in3_cb_index = CB::c_in3; + 
tt_metal::CircularBufferConfig in3_cb_config = tt_metal::CircularBufferConfig(in3_CB_size, {{in3_cb_index, cb_data_format}}) + .set_page_size(in3_cb_index, single_tile_size); + auto cb_in3 = tt_metal::CreateCircularBuffer(program, all_cores, in3_cb_config); + // gamma + if (gamma.has_value()) { + uint32_t in5_cb_index = CB::c_in5; + tt_metal::CircularBufferConfig in5_cb_config = tt_metal::CircularBufferConfig(in5_CB_size, {{in5_cb_index, gamma_beta_cb_data_format}}) + .set_page_size(in5_cb_index, gamma_beta_single_tile_size); + auto cb_in5 = tt_metal::CreateCircularBuffer(program, all_cores, in5_cb_config); + } + // beta + if (beta.has_value()) { + uint32_t in6_cb_index = CB::c_in6; + tt_metal::CircularBufferConfig in6_cb_config = tt_metal::CircularBufferConfig(in6_CB_size, {{in6_cb_index, gamma_beta_cb_data_format}}) + .set_page_size(in6_cb_index, gamma_beta_single_tile_size); + auto cb_in6 = tt_metal::CreateCircularBuffer(program, all_cores, in6_cb_config); + } + // x + uint32_t x_cb_index; + x_cb_index = CB::c_intermed0; + tt_metal::CircularBufferConfig x_cb_config = tt_metal::CircularBufferConfig(x_CB_size, {{x_cb_index, cb_data_format}}) + .set_page_size(x_cb_index, single_tile_size); + auto cb_x = tt_metal::CreateCircularBuffer(program, all_cores, x_cb_config); + // xmm + uint32_t xmm_cb_index; + xmm_cb_index = CB::c_intermed1; + tt_metal::CircularBufferConfig xmm_cb_config = tt_metal::CircularBufferConfig(xmm_CB_size, {{xmm_cb_index, cb_data_format}}) + .set_page_size(xmm_cb_index, single_tile_size); + auto cb_xmm = tt_metal::CreateCircularBuffer(program, all_cores, xmm_cb_config); + // ex_partial + uint32_t ex_cb_partial_index = CB::dataflow0; + tt_metal::CircularBufferConfig ex_cb_partial_config = tt_metal::CircularBufferConfig(ex_partial_CB_size, {{ex_cb_partial_index, cb_data_format}}) + .set_page_size(ex_cb_partial_index, single_tile_size); + auto cb_ex_partial = tt_metal::CreateCircularBuffer(program, all_cores, ex_cb_partial_config); + // ex + uint32_t ex_cb_index = CB::dataflow1; + // ex_external + uint32_t ex_cb_external_index = CB::dataflow2; + tt_metal::CircularBufferConfig ex_cb_external_config = tt_metal::CircularBufferConfig(ex_external_CB_size, {{ex_cb_external_index, cb_data_format}}) + .set_page_size(ex_cb_external_index, single_tile_size); + auto cb_ex_external = tt_metal::CreateCircularBuffer(program, all_cores, ex_cb_external_config); + // ex_global + uint32_t ex_global_cb_index = CB::dataflow7; + std::map ex_global_cb_data_format_spec { + {ex_global_cb_index, cb_data_format}, + {ex_cb_index, cb_data_format} + }; + auto ex_global_cb_config = tt_metal::CircularBufferConfig(ex_global_CB_size, ex_global_cb_data_format_spec) + .set_page_size(ex_global_cb_index, single_tile_size) + .set_page_size(ex_cb_index, single_tile_size); + auto cb_ex_global = tt_metal::CreateCircularBuffer(program, all_cores, ex_global_cb_config); + // ex2pe + uint32_t cb_ex2pe_index; + cb_ex2pe_index = CB::c_intermed3; + tt_metal::CircularBufferConfig ex2pe_cb_config = tt_metal::CircularBufferConfig(ex2pe_CB_size, {{cb_ex2pe_index, cb_data_format}}) + .set_page_size(cb_ex2pe_index, single_tile_size); + auto cb_ex2pe = tt_metal::CreateCircularBuffer(program, all_cores, ex2pe_cb_config); + // out + uint32_t output_cb_index = CB::c_out0; // output operands start at index 16 + tt_metal::CircularBufferConfig output_cb_config = tt_metal::CircularBufferConfig(out_CB_size, {{output_cb_index, out_data_format}}) + .set_page_size(output_cb_index, 
out_single_tile_size).set_globally_allocated_address(*output.buffer()); + auto cb_output = tt_metal::CreateCircularBuffer(program, all_cores, output_cb_config); + + // Runtime Args + std::vector writer_kernel_ids; + float winv = 1.0f / std::sqrt(block_h * block_w); // bcast-w scaler + float cinv = 1.0f / std::sqrt(num_cores_per_batch * num_cores_per_group); // bcast-cores scaler + bfloat16 bfloat_cinv_value = bfloat16(cinv); + uint32_t packed_cinv_value = pack_two_bfloat16_into_uint32({bfloat_cinv_value, bfloat_cinv_value}); + bfloat16 bfloat_winv_value = bfloat16(winv); + uint32_t packed_winv_value = pack_two_bfloat16_into_uint32({bfloat_winv_value, bfloat_winv_value}); + union { float f; uint32_t u; } e; e.f = eps; + + uint32_t gamma_tile_start_id = 0; + uint32_t beta_tile_start_id = 0; + for (int i=0; i < mcast_groups.size(); ++i) { + auto group = mcast_groups[i]; + bool rectangle_grid = is_rectangle_grid(group); + + for (int j=0; j < group.size(); ++j) { + CoreCoord core = group[j]; + CoreCoord core_physical = device->worker_core_from_logical_core(core); + + if (j == 0) { // mcast sender + // get the bounding box for the mcast + std::vector mcast_group_first; + std::vector mcast_group_mid(group); + std::vector mcast_group_last; + if (not rectangle_grid) { + split_and_form_rectangle_grids(group, mcast_group_first, mcast_group_mid, mcast_group_last); + } + + CoreCoord mcast_start = device->worker_core_from_logical_core(mcast_group_mid.front()); + CoreCoord mcast_end = device->worker_core_from_logical_core(mcast_group_mid.back()); + + if ((mcast_start.x < mcast_end.x) or (mcast_start.y < mcast_end.y)) { + std::swap(mcast_start, mcast_end); + } + + std::vector mcast_sender_args; + mcast_sender_args.push_back(not mcast_group_first.empty()); + mcast_sender_args.push_back(not mcast_group_last.empty()); + mcast_sender_args.push_back(mcast_start.x); + mcast_sender_args.push_back(mcast_start.y); + mcast_sender_args.push_back(mcast_end.x); + mcast_sender_args.push_back(mcast_end.y); + if (not mcast_group_first.empty()) { + mcast_sender_args.push_back(mcast_group_mid.size()); + } else { + mcast_sender_args.push_back(mcast_group_mid.size() - 1); // mcast w/o itself + } + + if (not mcast_group_first.empty()) { + CoreCoord mcast_first_start = device->worker_core_from_logical_core(mcast_group_first.front()); + CoreCoord mcast_first_end = device->worker_core_from_logical_core(mcast_group_first.back()); + + if ((mcast_first_start.x < mcast_first_end.x) or (mcast_first_start.y < mcast_first_end.y)) { + std::swap(mcast_first_start, mcast_first_end); + } + + mcast_sender_args.push_back(mcast_first_start.x); + mcast_sender_args.push_back(mcast_first_start.y); + mcast_sender_args.push_back(mcast_first_end.x); + mcast_sender_args.push_back(mcast_first_end.y); + mcast_sender_args.push_back(mcast_group_first.size() - 1); // mcast w/0 itself + } + if (not mcast_group_last.empty()) { + CoreCoord mcast_last_start = device->worker_core_from_logical_core(mcast_group_last.front()); + CoreCoord mcast_last_end = device->worker_core_from_logical_core(mcast_group_last.back()); + + if ((mcast_last_start.x < mcast_last_end.x) or (mcast_last_start.y < mcast_last_end.y)) { + std::swap(mcast_last_start, mcast_last_end); + } + + mcast_sender_args.push_back(mcast_last_start.x); + mcast_sender_args.push_back(mcast_last_start.y); + mcast_sender_args.push_back(mcast_last_end.x); + mcast_sender_args.push_back(mcast_last_end.y); + mcast_sender_args.push_back(mcast_group_last.size()); + } + + // add all coords within a group + 
std::vector mcast_noc_xy; + for (int c=0; c < group.size(); ++c) { + CoreCoord coord = device->worker_core_from_logical_core(group[c]); + mcast_noc_xy.push_back(coord.x); + } + for (int c=0; c < group.size(); ++c) { + CoreCoord coord = device->worker_core_from_logical_core(group[c]); + mcast_noc_xy.push_back(coord.y); + } + mcast_sender_args.insert(mcast_sender_args.end(), mcast_noc_xy.begin(), mcast_noc_xy.end()); + tt_metal::SetRuntimeArgs(program, reader_mcast_sender_kernels_id, core, mcast_sender_args); + + } else { // mcast receiver + std::vector mcast_receiver_args; + mcast_receiver_args.push_back(device->worker_core_from_logical_core(group.front()).x); + mcast_receiver_args.push_back(device->worker_core_from_logical_core(group.front()).y); + tt_metal::SetRuntimeArgs(program, reader_mcast_receiver_kernels_id, core, mcast_receiver_args); + } + + } + } + + // writer + for (int i=0; i < core_coords.size(); ++i) { + + auto core = core_coords[i]; + + std::vector writer_mcast_sender_args; + writer_mcast_sender_args.push_back(packed_cinv_value); + writer_mcast_sender_args.push_back(packed_winv_value); + writer_mcast_sender_args.push_back(e.u); + writer_mcast_sender_args.push_back(gamma_dram_addr); + writer_mcast_sender_args.push_back(beta_dram_addr); + writer_mcast_sender_args.push_back(gamma_tile_start_id); + writer_mcast_sender_args.push_back(beta_tile_start_id); + tt_metal::SetRuntimeArgs(program, writer_mcast_receiver_kernels_id, core, writer_mcast_sender_args); + writer_kernel_ids.push_back(writer_mcast_receiver_kernels_id); + + gamma_tile_start_id = (gamma_tile_start_id + per_core_Nt) % Wt; + beta_tile_start_id = (beta_tile_start_id + per_core_Nt) % Wt; + } + + auto override_runtime_args_callback = [ + writer_kernel_ids, + cb_in0, + cb_output, + num_cores, + grid_size + ] + ( + const void* operation, + Program &program, + const std::vector& input_tensors, + const std::vector>& optional_input_tensors, + const std::vector& output_tensors + ) { + auto src_buffer_a = input_tensors.at(0).buffer(); + auto gamma_tensor = optional_input_tensors.at(0); + auto beta_tensor = optional_input_tensors.at(1); + auto dst_buffer = output_tensors.at(0).buffer(); + + UpdateDynamicCircularBufferAddress(program, cb_in0, *src_buffer_a); + UpdateDynamicCircularBufferAddress(program, cb_output, *dst_buffer); + + for (uint32_t i = 0; i < num_cores; ++i) { + CoreCoord core = {i % grid_size.x, i / grid_size.x}; + + auto writer_kernel_id = writer_kernel_ids.at(i); + + auto& runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); + + if (gamma_tensor.has_value()) { + runtime_args[3] = gamma_tensor.value().buffer()->address(); + } + if (beta_tensor.has_value()) { + runtime_args[4] = beta_tensor.value().buffer()->address(); + } + } + }; + + return {std::move(program), .override_runtime_arguments_callback=override_runtime_args_callback}; +} +void GroupNorm::validate(const std::vector &input_tensors, const std::vector>& optional_input_tensors) const { + TT_FATAL(input_tensors.size() == 1 and optional_input_tensors.size() <= 2, "Must have between 1 to 3 input tensors"); + auto& a = input_tensors.at(0); + const auto& gamma = optional_input_tensors.at(0); + const auto& beta = optional_input_tensors.at(1); + TT_FATAL(a.layout() == Layout::TILE or a.layout() == Layout::ROW_MAJOR); + TT_FATAL(a.dtype() == DataType::BFLOAT16 or a.dtype() == DataType::BFLOAT8_B); + TT_FATAL(a.storage_type() == StorageType::DEVICE, "Operands to layernorm need to be on device!"); + TT_FATAL(a.buffer() != nullptr, "Operands to layernorm 
need to be allocated in buffers on device!");
+    TT_FATAL(a.shape()[3] % this->num_groups == 0, "channel must be divisible by num_groups!");
+
+    if (gamma.has_value()) {
+        if (gamma.value().layout() == Layout::TILE) {
+            TT_FATAL(a.shape()[3] == gamma.value().shape()[3], fmt::format("{} != {}", a.shape()[3], gamma.value().shape()[3]));
+            TT_FATAL(a.device() == gamma.value().device());
+            TT_FATAL(gamma.value().buffer() != nullptr, "Operands to groupnorm need to be allocated in buffers on device!");
+            TT_FATAL(gamma.value().shape()[2] == TILE_HEIGHT);
+        } else {
+            TT_FATAL(gamma.value().layout() == Layout::ROW_MAJOR);
+            TT_FATAL((gamma.value().shape()[3] == TILE_WIDTH && gamma.value().volume() / TILE_WIDTH == a.shape()[3] / TILE_WIDTH));
+            TT_FATAL(a.device() == gamma.value().device());
+            TT_FATAL(gamma.value().buffer() != nullptr, "Operands to groupnorm need to be allocated in buffers on device!");
+            TT_FATAL(gamma.value().dtype() == DataType::BFLOAT16);
+        }
+        if (beta.has_value()) {
+            TT_FATAL(gamma.value().layout() == beta.value().layout());
+        }
+    }
+
+    if (beta.has_value()) {
+        if (beta.value().layout() == Layout::TILE) {
+            TT_FATAL(a.shape()[3] == beta.value().shape()[3]);
+            TT_FATAL(a.device() == beta.value().device());
+            TT_FATAL(beta.value().buffer() != nullptr, "Operands to groupnorm need to be allocated in buffers on device!");
+            TT_FATAL(beta.value().shape()[2] == TILE_HEIGHT);
+        } else {
+            TT_FATAL(beta.value().layout() == Layout::ROW_MAJOR);
+            TT_FATAL((beta.value().shape()[3] == TILE_WIDTH && beta.value().volume() / TILE_WIDTH == a.shape()[3] / TILE_WIDTH));
+            TT_FATAL(a.device() == beta.value().device());
+            TT_FATAL(beta.value().buffer() != nullptr, "Operands to groupnorm need to be allocated in buffers on device!");
+            TT_FATAL(beta.value().dtype() == DataType::BFLOAT16);
+        }
+    }
+}
+std::vector<Shape> GroupNorm::compute_output_shapes(const std::vector<Tensor> &input_tensors) const {
+    const auto& input_tensor = input_tensors.at(0);
+    return {input_tensor.shape()};
+}
+std::vector<Tensor> GroupNorm::create_output_tensors(const std::vector<Tensor> &input_tensors) const {
+    const auto& input_tensor = input_tensors.at(0);
+    if (this->program_config.inplace) {
+        return {input_tensors.at(0)};
+    } else {
+        auto mem_config = this->output_mem_config;
+        mem_config.shard_spec = input_tensor.shard_spec();
+        return {create_sharded_device_tensor(this->compute_output_shapes(input_tensors).at(0), program_config.out_data_format, Layout::TILE, input_tensor.device(), mem_config)};
+    }
+}
+operation::ProgramWithCallbacks GroupNorm::create_program(
+    const std::vector<Tensor>& input_tensors,
+    const std::vector<std::optional<const Tensor>>& optional_input_tensors,
+    std::vector<Tensor> &output_tensors
+) const {
+    const auto& a = input_tensors.at(0);
+    const auto& gamma = optional_input_tensors.at(0);
+    const auto& beta = optional_input_tensors.at(1);
+    auto& output_tensor = output_tensors.at(0);
+
+    MathFidelity fidelity = this->program_config.math_fidelity;
+    uint32_t num_cores_x = this->program_config.compute_with_storage_grid_size.x;
+    uint32_t num_cores_y = this->program_config.compute_with_storage_grid_size.y;
+    CoreCoord grid_size = CoreCoord(num_cores_x, num_cores_y);
+
+    return groupnorm_sharded_(
+        a, gamma, beta, output_tensor, this->eps,
+        this->num_groups, this->batch,
+        fidelity,
+        program_config.im_data_format,
+        program_config.compute_with_storage_grid_size
+    );
+}
+tt::stl::reflection::Attributes GroupNorm::attributes() const {
+    return {
+        {"eps", this->eps},
+        {"num_groups", this->num_groups},
+        {"output_mem_config", this->output_mem_config}
+    };
+}
+
+} // namespace primary
+} // namespace operations
+ } // namespace tt
diff --git a/tt_eager/tt_dnn/op_library/groupnorm/groupnorm_op.hpp b/tt_eager/tt_dnn/op_library/groupnorm/groupnorm_op.hpp
index 61855da1bcb..3014fe32e6d 100644
--- a/tt_eager/tt_dnn/op_library/groupnorm/groupnorm_op.hpp
+++ b/tt_eager/tt_dnn/op_library/groupnorm/groupnorm_op.hpp
@@ -38,4 +38,45 @@ Tensor groupnorm(
     const MemoryConfig& mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
 } // namespace tt_metal
+
+
+namespace operations {
+
+    using namespace tt_metal;
+namespace primary {
+struct GroupNormShardedMultiCoreProgramConfig {
+    CoreCoord compute_with_storage_grid_size;
+    MathFidelity math_fidelity;
+    DataType im_data_format;
+    DataType out_data_format;
+    bool inplace;
+
+    tt::stl::reflection::Attributes attributes() const;
+};
+
+struct GroupNorm {
+    float eps;
+    uint32_t num_groups;
+    uint32_t batch;
+    MemoryConfig output_mem_config;
+    tt::operations::primary::GroupNormShardedMultiCoreProgramConfig program_config;
+
+    void validate(const std::vector<Tensor> &input_tensors, const std::vector<std::optional<const Tensor>>& optional_input_tensors) const;
+    std::vector<Shape> compute_output_shapes(const std::vector<Tensor> &input_tensors) const;
+    std::vector<Tensor> create_output_tensors(const std::vector<Tensor> &input_tensors) const;
+    operation::ProgramWithCallbacks create_program(
+        const std::vector<Tensor>& input_tensors,
+        const std::vector<std::optional<const Tensor>>& optional_input_tensors,
+        std::vector<Tensor> &output_tensors
+    ) const;
+    tt::stl::reflection::Attributes attributes() const;
+};
+
+inline Tensor groupnorm(const Tensor &a, const uint32_t num_groups, uint32_t batch, float eps, std::optional<const Tensor> gamma = std::nullopt, std::optional<const Tensor> beta = std::nullopt, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, const GroupNormShardedMultiCoreProgramConfig& program_config = GroupNormShardedMultiCoreProgramConfig{}) {
+    return operation::run(GroupNorm{.eps=eps, .num_groups=num_groups, .batch=batch, .output_mem_config=output_mem_config, .program_config=program_config}, {a}, {gamma, beta}).at(0);
+}
+
+
+} // namespace primary
+} // namespace operations
 } // namespace tt
diff --git a/tt_eager/tt_dnn/op_library/groupnorm/kernels/compute/groupnorm_sharded.cpp b/tt_eager/tt_dnn/op_library/groupnorm/kernels/compute/groupnorm_sharded.cpp
new file mode 100644
index 00000000000..b10a68403cb
--- /dev/null
+++ b/tt_eager/tt_dnn/op_library/groupnorm/kernels/compute/groupnorm_sharded.cpp
@@ -0,0 +1,461 @@
+// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
+// +// SPDX-License-Identifier: Apache-2.0 + +#include + +#define REDUCE_OP PoolType::SUM +#define REDUCE_DIM ReduceDim::REDUCE_SCALAR + +#define BCAST_LLKOP EltwiseBinaryType::ELWMUL +#define BCAST_DIM BroadcastType::COL + +#include "compute_kernel_api/reduce.h" +#include "compute_kernel_api/bcast.h" +#include "compute_kernel_api/eltwise_binary.h" +#include "compute_kernel_api/layernorm.h" +#include "compute_kernel_api/tile_move_copy.h" +#include "compute_kernel_api/tilize.h" +#include "compute_kernel_api/untilize.h" +#include "compute_kernel_api/matmul.h" + +// #include "debug/dprint.h" + + +inline void tilize_in( + uint32_t in_cb_id, + uint32_t out_cb_id, + uint32_t block_h, + uint32_t block_w +) { + tilize_init_short(in_cb_id, block_w); + for (uint32_t h = 0; h < block_h; ++h) { + cb_reserve_back(out_cb_id, block_w); + tilize_block(in_cb_id, block_w, out_cb_id); + cb_push_back(out_cb_id, block_w); + cb_pop_front(in_cb_id, block_w); + } + tilize_uninit(); +} + +inline void untilize_out( + uint32_t in_cb_id, + uint32_t out_cb_id, + uint32_t block_h, + uint32_t block_w +) { + untilize_init_short(in_cb_id); + for (uint32_t h = 0; h < block_h; ++h) { + cb_wait_front(in_cb_id, block_w); + cb_reserve_back(out_cb_id, block_w); + untilize_block(in_cb_id, block_w, out_cb_id); + cb_pop_front(in_cb_id, block_w); + cb_push_back(out_cb_id, block_w); + } + untilize_uninit(in_cb_id); +} + +// SPLIT REDUCE across Cores +namespace NAMESPACE { +void MAIN { + + constexpr uint32_t is_mcast_sender = get_compile_time_arg_val(0); + constexpr uint32_t do_gamma = get_compile_time_arg_val(1); + constexpr uint32_t do_beta = get_compile_time_arg_val(2); + constexpr uint32_t num_cores_per_mcast_group = get_compile_time_arg_val(3); + + constexpr uint32_t batch = get_compile_time_arg_val(4); + constexpr uint32_t group = get_compile_time_arg_val(5); + + constexpr uint32_t num_batch_group = get_compile_time_arg_val(6); + + volatile uint32_t block_h = get_compile_time_arg_val(7); + constexpr uint32_t block_w = get_compile_time_arg_val(8); + constexpr uint32_t block_hw = get_compile_time_arg_val(9); + + constexpr uint32_t subblock_w = get_compile_time_arg_val(10); + constexpr uint32_t num_subblocks_w = get_compile_time_arg_val(11); + + constexpr uint32_t tilize_in0 = get_compile_time_arg_val(12); + + constexpr uint32_t per_core_M = get_compile_time_arg_val(13); + constexpr uint32_t per_core_N = get_compile_time_arg_val(14); + constexpr uint32_t per_core_MN = get_compile_time_arg_val(15); + constexpr uint32_t block_bhw = get_compile_time_arg_val(16); + + + constexpr uint32_t dst0 = 0; + constexpr uint32_t scaler0 = 0; + + constexpr uint32_t cb_in0 = tt::CB::c_in0; + constexpr uint32_t cb_scaler = tt::CB::c_in2; + constexpr uint32_t cb_eps = tt::CB::c_in3; + constexpr uint32_t cb_scaler_global = tt::CB::c_in4; + constexpr uint32_t cb_gamma = tt::CB::c_in5; + constexpr uint32_t cb_beta = tt::CB::c_in6; + constexpr uint32_t cb_x = tt::CB::c_intermed0; // x minus mean + constexpr uint32_t cb_xmm = tt::CB::c_intermed1; // x minus mean + constexpr uint32_t cb_ex_partial = tt::CB::dataflow0; // E[x] partial reduce + constexpr uint32_t cb_ex = tt::CB::dataflow1; // E[x] global reduce + constexpr uint32_t cb_ex_external = tt::CB::dataflow2; + constexpr uint32_t cb_xmm2 = cb_x; // xmm^2 + constexpr uint32_t cb_ex2pe = tt::CB::c_intermed3; // E[(x-E[x])^2]+eps + constexpr uint32_t cb_fusion = cb_xmm; // stream gamma/beta + constexpr uint32_t cb_out = tt::CB::c_out0; + constexpr uint32_t cb_ex_global = num_cores_per_mcast_group == 1 
? cb_ex_partial : tt::CB::dataflow7; + + uint32_t index_subblock_w_offset = 0; + uint32_t index_h_offset = 0; + uint32_t index_b_offset = 0; + uint32_t index_g_offset = 0; + uint32_t index_b_offset_ex = 0; + + constexpr int cb_in = tilize_in0 ? cb_x : cb_in0; + constexpr int cb_im = (do_gamma | do_beta) ? cb_x : cb_out; + constexpr int cb_outgamma = do_beta ? cb_fusion : cb_out; + + binary_op_init_common(cb_in0, cb_in0, cb_xmm); + + if constexpr (tilize_in0) { + tilize_in(cb_in0, cb_in, per_core_M, per_core_N); + cb_wait_front(cb_in, per_core_MN); + } + + // Partial-E[x] for each core + unpack_reconfig_data_format(cb_in, cb_scaler); + reduce_init_delta(REDUCE_OP, REDUCE_DIM); + cb_reserve_back(cb_ex_partial, num_batch_group); + cb_wait_front(cb_scaler, 1); + index_b_offset = 0; + for (uint32_t b = 0; b < batch; ++b) { + index_g_offset = 0; + for (uint32_t g = 0; g < group; ++g) { + uint32_t index_bg_offset = index_b_offset + index_g_offset; + index_h_offset = 0; + tile_regs_acquire(); + for (uint32_t h = 0; h < block_h; ++h) { + for (uint32_t w = 0; w < block_w; ++w) { + uint32_t index = index_bg_offset + index_h_offset + w; + reduce_tile(REDUCE_OP, REDUCE_DIM, cb_in, cb_scaler, index, scaler0, dst0); + } + index_h_offset += per_core_N; + } + tile_regs_commit(); + tile_regs_wait(); + pack_tile(dst0, cb_ex_partial); + tile_regs_release(); + index_g_offset += block_w; + } + index_b_offset += block_bhw; + } + cb_push_back(cb_ex_partial, num_batch_group); + reduce_revert_delta(); + unpack_reconfig_data_format(cb_xmm, cb_xmm); + + if constexpr(is_mcast_sender and num_cores_per_mcast_group > 1) { + index_b_offset = 0; + reduce_init_delta(REDUCE_OP, REDUCE_DIM); + cb_reserve_back(cb_ex_global, num_batch_group); + cb_reserve_back(cb_ex, num_batch_group); + for (uint32_t bg = 0; bg < num_batch_group; ++bg) { + tile_regs_acquire(); + cb_wait_front(cb_scaler_global, 1); + for (uint32_t w = 0; w < num_cores_per_mcast_group; w++) { + cb_wait_front(cb_ex_external, 1); + reduce_tile(REDUCE_OP, REDUCE_DIM, cb_ex_external, cb_scaler_global, 0, scaler0, dst0); + cb_pop_front(cb_ex_external, 1); + } + tile_regs_commit(); + tile_regs_wait(); + pack_tile(dst0, cb_ex_global); + tile_regs_release(); + } + reduce_revert_delta(); + cb_push_back(cb_ex_global, num_batch_group); + cb_push_back(cb_ex, num_batch_group); + } + + // x - E[x] + sub_tiles_bcast_scalar_init_short(); + cb_reserve_back(cb_xmm, per_core_MN); + cb_wait_front(cb_ex_global, num_batch_group); + unpack_reconfig_data_format(cb_in, cb_ex_global); + index_b_offset = 0; + index_b_offset_ex = 0; + for (uint32_t b = 0; b < batch; ++b) { + index_h_offset = 0; + for (uint32_t i = 0; i < block_h; i++) { + index_g_offset = 0; + for (uint32_t g = 0; g < group; ++g) { + uint32_t index_bhg_offset = index_b_offset + index_h_offset + index_g_offset; + index_subblock_w_offset = 0; + for (uint32_t j = 0; j < num_subblocks_w; j++) { + tile_regs_acquire(); + for (uint32_t w = 0; w < subblock_w; w++) { + uint32_t index = index_bhg_offset + index_subblock_w_offset + w; + uint32_t index_ex = index_b_offset_ex + g; + sub_tiles_bcast_scalar(cb_in, cb_ex_global, index, index_ex, w); + } + tile_regs_commit(); + tile_regs_wait(); + for (uint32_t i = 0; i < subblock_w; i++) { + pack_tile(i, cb_xmm); + } + tile_regs_release(); + index_subblock_w_offset += subblock_w; + } + index_g_offset += block_w; + } + index_h_offset += per_core_N; + } + index_b_offset_ex += group; + index_b_offset += block_bhw; + } + cb_pop_front(cb_in, per_core_MN); + cb_pop_front(cb_ex_global, 
num_batch_group); + cb_push_back(cb_xmm, per_core_MN); + cb_wait_front(cb_xmm, per_core_MN); + unpack_reconfig_data_format(cb_xmm2, cb_xmm2); + + // (x - E[x])^2 + mul_tiles_init(); + cb_reserve_back(cb_xmm2, per_core_MN); + index_b_offset = 0; + for (uint32_t b = 0; b < batch; ++b) { + index_h_offset = 0; + for (uint32_t i = 0; i < block_h; i++) { + index_g_offset = 0; + for (uint32_t g = 0; g < group; ++g) { + uint32_t index_bhg_offset = index_b_offset + index_h_offset + index_g_offset; + index_subblock_w_offset = 0; + for (uint32_t j = 0; j < num_subblocks_w; j++) { + tile_regs_acquire(); + for (uint32_t w = 0; w < subblock_w; w++) { + uint32_t index = index_bhg_offset + index_subblock_w_offset + w; + mul_tiles(cb_xmm, cb_xmm, index, index, w); + } + tile_regs_commit(); + tile_regs_wait(); + for (uint32_t i = 0; i < subblock_w; i++) { + pack_tile(i, cb_xmm2); + } + tile_regs_release(); + index_subblock_w_offset += subblock_w; + } + index_g_offset += block_w; + } + index_h_offset += per_core_N; + } + index_b_offset += block_bhw; + } + cb_push_back(cb_xmm2, per_core_MN); + + // Partial-Var(x) + index_b_offset = 0; + reduce_init_delta(REDUCE_OP, REDUCE_DIM); + cb_reserve_back(cb_ex_partial, num_batch_group); + cb_wait_front(cb_xmm2, per_core_MN); + cb_wait_front(cb_scaler, 1); + for (uint32_t b = 0; b < batch; ++b) { + index_g_offset = 0; + for (uint32_t g = 0; g < group; ++g) { + uint32_t index_bg_offset = index_b_offset + index_g_offset; + index_h_offset = 0; + tile_regs_acquire(); + for (uint32_t h = 0; h < block_h; ++h) { + for (uint32_t w = 0; w < block_w; ++w) { + uint32_t index = index_bg_offset + index_h_offset + w; + reduce_tile(REDUCE_OP, REDUCE_DIM, cb_xmm2, cb_scaler, index, scaler0, dst0); + } + index_h_offset += per_core_N; + } + tile_regs_commit(); + tile_regs_wait(); + pack_tile(dst0, cb_ex_partial); + tile_regs_release(); + index_g_offset += block_w; + } + index_b_offset += block_bhw; + } + cb_push_back(cb_ex_partial, num_batch_group); + cb_pop_front(cb_xmm2, per_core_MN); + reduce_revert_delta(); + + // global reduce + if constexpr(is_mcast_sender and num_cores_per_mcast_group > 1) { + index_b_offset = 0; + reduce_init_delta(REDUCE_OP, REDUCE_DIM); + cb_reserve_back(cb_ex_global, num_batch_group); + cb_reserve_back(cb_ex, num_batch_group); + for (uint32_t bg = 0; bg < num_batch_group; ++bg) { + tile_regs_acquire(); + cb_wait_front(cb_scaler_global, 1); + for (uint32_t w = 0; w < num_cores_per_mcast_group; w++) { + cb_wait_front(cb_ex_external, 1); + reduce_tile(REDUCE_OP, REDUCE_DIM, cb_ex_external, cb_scaler_global, 0, scaler0, dst0); + cb_pop_front(cb_ex_external, 1); + } + tile_regs_commit(); + tile_regs_wait(); + pack_tile(dst0, cb_ex_global); + tile_regs_release(); + } + reduce_revert_delta(); + cb_push_back(cb_ex_global, num_batch_group); + cb_push_back(cb_ex, num_batch_group); + } + + cb_wait_front(cb_ex_global, num_batch_group); + cb_reserve_back(cb_ex2pe, num_batch_group); + for (uint32_t bg = 0; bg < num_batch_group; ++bg) { + // (Var + eps) + tile_regs_acquire(); + add_tiles_init(); + add_tiles(cb_ex_global, cb_eps, bg, 0, dst0); + tile_regs_wait(); + // sqrt(Var + eps) + sqrt_tile_init(); + sqrt_tile(dst0); + tile_regs_wait(); + // 1/[sqrt(Var + eps)] + recip_tile_init(); + recip_tile(dst0); + tile_regs_commit(); + tile_regs_wait(); + pack_tile(dst0, cb_ex2pe); + tile_regs_release(); + } + cb_push_back(cb_ex2pe, num_batch_group); + cb_pop_front(cb_ex_global, num_batch_group); + + // (x - Ex) * 1/[sqrt(Var + eps)] + if constexpr(do_gamma == 0 && do_beta == 0) { 
+ pack_reconfig_data_format(cb_out); + } + mul_tiles_bcast_scalar_init_short(); + cb_reserve_back(cb_im, per_core_MN); + cb_wait_front(cb_ex2pe, num_batch_group); + index_b_offset = 0; + index_b_offset_ex = 0; + for (uint32_t b = 0; b < batch; ++b) { + index_h_offset = 0; + for (uint32_t i = 0; i < block_h; i++) { + index_g_offset = 0; + for (uint32_t g = 0; g < group; ++g) { + uint32_t index_bhg_offset = index_b_offset + index_h_offset + index_g_offset; + index_subblock_w_offset = 0; + for (uint32_t j = 0; j < num_subblocks_w; j++) { + tile_regs_acquire(); + for (uint32_t w = 0; w < subblock_w; w++) { + uint32_t index = index_bhg_offset + index_subblock_w_offset + w; + uint32_t index_ex2pe = index_b_offset_ex + g; + mul_tiles_bcast_scalar(cb_xmm, cb_ex2pe, index, index_ex2pe, w); + } + tile_regs_commit(); + tile_regs_wait(); + for (uint32_t i = 0; i < subblock_w; i++) { + pack_tile(i, cb_im); + } + tile_regs_release(); + index_subblock_w_offset += subblock_w; + } + index_g_offset += block_w; + } + index_h_offset += per_core_N; + } + index_b_offset_ex += group; + index_b_offset += block_bhw; + } + cb_push_back(cb_im, per_core_MN); + cb_pop_front(cb_ex2pe, num_batch_group); + cb_pop_front(cb_xmm, per_core_MN); + cb_wait_front(cb_im, per_core_MN); + + if constexpr(do_gamma) { + unpack_reconfig_data_format(cb_im, cb_gamma); + if constexpr(do_beta == 0) { + pack_reconfig_data_format(cb_out); + } + mul_bcast_rows_init_short(); + cb_reserve_back(cb_outgamma, per_core_MN); + cb_wait_front(cb_gamma, per_core_N); + index_b_offset = 0; + index_b_offset_ex = 0; + for (uint32_t b = 0; b < batch; ++b) { + index_h_offset = 0; + for (uint32_t i = 0; i < block_h; i++) { + index_g_offset = 0; + for (uint32_t g = 0; g < group; ++g) { + uint32_t index_bhg_offset = index_b_offset + index_h_offset + index_g_offset; + index_subblock_w_offset = 0; + for (uint32_t j = 0; j < num_subblocks_w; j++) { + tile_regs_acquire(); + for (uint32_t w = 0; w < subblock_w; w++) { + uint32_t index = index_bhg_offset + index_subblock_w_offset + w; + uint32_t index_gm = index_subblock_w_offset + w + index_g_offset; + mul_tiles_bcast_rows(cb_im, cb_gamma, index, index_gm, w); + } + tile_regs_commit(); + tile_regs_wait(); + for (uint32_t i = 0; i < subblock_w; i++) { + pack_tile(i, cb_outgamma); + } + tile_regs_release(); + index_subblock_w_offset += subblock_w; + } + index_g_offset += block_w; + } + index_h_offset += per_core_N; + } + index_b_offset_ex += group; + index_b_offset += block_bhw; + } + cb_push_back(cb_outgamma, per_core_MN); + cb_pop_front(cb_im, per_core_MN); + cb_wait_front(cb_outgamma, per_core_MN); + } + + if constexpr(do_beta) { + unpack_reconfig_data_format(cb_fusion, cb_beta); + pack_reconfig_data_format(cb_out); + add_bcast_rows_init_short(); + cb_reserve_back(cb_out, per_core_MN); + cb_wait_front(cb_beta, per_core_N); + index_b_offset = 0; + index_b_offset_ex = 0; + for (uint32_t b = 0; b < batch; ++b) { + index_h_offset = 0; + for (uint32_t i = 0; i < block_h; i++) { + index_g_offset = 0; + for (uint32_t g = 0; g < group; ++g) { + uint32_t index_bhg_offset = index_b_offset + index_h_offset + index_g_offset; + index_subblock_w_offset = 0; + for (uint32_t j = 0; j < num_subblocks_w; j++) { + tile_regs_acquire(); + for (uint32_t w = 0; w < subblock_w; w++) { + uint32_t index = index_bhg_offset + index_subblock_w_offset + w; + uint32_t index_gm = index_subblock_w_offset + w + index_g_offset; + add_tiles_bcast_rows(cb_fusion, cb_beta, index, index_gm, w); + } + tile_regs_commit(); + tile_regs_wait(); + for 
(uint32_t i = 0; i < subblock_w; i++) { + pack_tile(i, cb_out); + } + tile_regs_release(); + index_subblock_w_offset += subblock_w; + } + index_g_offset += block_w; + } + index_h_offset += per_core_N; + } + index_b_offset_ex += group; + index_b_offset += block_bhw; + } + cb_push_back(cb_out, per_core_MN); + cb_pop_front(cb_im, per_core_MN); + cb_wait_front(cb_out, per_core_MN); + } + +} +} diff --git a/tt_eager/tt_dnn/op_library/groupnorm/kernels/dataflow/reader_mcast_receiver_unary_sharded_gn.cpp b/tt_eager/tt_dnn/op_library/groupnorm/kernels/dataflow/reader_mcast_receiver_unary_sharded_gn.cpp new file mode 100644 index 00000000000..77e2b764325 --- /dev/null +++ b/tt_eager/tt_dnn/op_library/groupnorm/kernels/dataflow/reader_mcast_receiver_unary_sharded_gn.cpp @@ -0,0 +1,41 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include +#include "dataflow_api.h" +#include "hostdevcommon/common_values.hpp" + +// #include "debug/dprint.h" + +// split REDUCE across cores +void kernel_main() { + constexpr uint32_t reduce_receiver_semaphore_addr = get_compile_time_arg_val(0); + constexpr uint32_t reduce_sender_semaphore_addr = get_compile_time_arg_val(1); + constexpr uint32_t num_group_batch = get_compile_time_arg_val(2); + + const uint32_t mcast_sender_noc_x = get_arg_val(0); + const uint32_t mcast_sender_noc_y = get_arg_val(1); + + constexpr uint32_t cb_ex_partial = tt::CB::dataflow0; // E[x] partial reduce + constexpr uint32_t cb_ex_global = tt::CB::dataflow7; // E[x] global reduce + + const uint32_t single_tile_size_bytes = get_tile_size(cb_ex_partial); // tile size + const DataFormat data_format = get_dataformat(cb_ex_partial); // data format + + volatile tt_l1_ptr uint32_t* reduce_receiver_semaphore_addr_ptr = reinterpret_cast(reduce_receiver_semaphore_addr); + volatile tt_l1_ptr uint32_t* reduce_sender_semaphore_addr_ptr = reinterpret_cast(reduce_sender_semaphore_addr); + + const uint64_t reduce_receiver_semaphore_noc_addr = get_noc_addr(mcast_sender_noc_x, mcast_sender_noc_y, reduce_receiver_semaphore_addr); + + for (uint32_t i=0; i < 2; ++i) { + // wait for local data ready + cb_wait_front(cb_ex_partial, num_group_batch); + noc_semaphore_set(reduce_sender_semaphore_addr_ptr, INVALID); + cb_reserve_back(cb_ex_global, num_group_batch); + noc_semaphore_inc(reduce_receiver_semaphore_noc_addr, 1); + noc_semaphore_wait(reduce_sender_semaphore_addr_ptr, VALID); + cb_push_back(cb_ex_global, num_group_batch); + cb_pop_front(cb_ex_partial, num_group_batch); + } +} diff --git a/tt_eager/tt_dnn/op_library/groupnorm/kernels/dataflow/reader_mcast_sender_unary_sharded_gn.cpp b/tt_eager/tt_dnn/op_library/groupnorm/kernels/dataflow/reader_mcast_sender_unary_sharded_gn.cpp new file mode 100644 index 00000000000..caa424430fd --- /dev/null +++ b/tt_eager/tt_dnn/op_library/groupnorm/kernels/dataflow/reader_mcast_sender_unary_sharded_gn.cpp @@ -0,0 +1,195 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. 
+// +// SPDX-License-Identifier: Apache-2.0 + +#include +#include "dataflow_api.h" +#include "hostdevcommon/common_values.hpp" +// #include "debug/dprint.h" + +// split REDUCE across cores +void kernel_main() { + constexpr uint32_t reduce_receiver_semaphore_addr = get_compile_time_arg_val(0); + constexpr uint32_t reduce_sender_semaphore_addr = get_compile_time_arg_val(1); + constexpr uint32_t num_mcast_cores = get_compile_time_arg_val(2); + constexpr uint32_t num_group_batch = get_compile_time_arg_val(3); + + const bool has_mcast_first_group = get_arg_val(0); + const bool has_mcast_last_group = get_arg_val(1); + + // mid mcast group + const uint32_t mcast_dest_noc_start_x = get_arg_val(2); + const uint32_t mcast_dest_noc_start_y = get_arg_val(3); + const uint32_t mcast_dest_noc_end_x = get_arg_val(4); + const uint32_t mcast_dest_noc_end_y = get_arg_val(5); + const uint32_t num_mcast_cores_mid_group = get_arg_val(6); + + // first mcast group + uint32_t mcast_first_group_dest_noc_start_x; + uint32_t mcast_first_group_dest_noc_start_y; + uint32_t mcast_first_group_dest_noc_end_x; + uint32_t mcast_first_group_dest_noc_end_y; + // last mcast group + uint32_t mcast_last_group_dest_noc_start_x; + uint32_t mcast_last_group_dest_noc_start_y; + uint32_t mcast_last_group_dest_noc_end_x; + uint32_t mcast_last_group_dest_noc_end_y; + // volatile tt_l1_ptr uint32_t * noc_coord; + volatile tt_l1_ptr uint32_t * noc_coord_x; + volatile tt_l1_ptr uint32_t * noc_coord_y; + + // number of cores in mcast groups + uint32_t num_mcast_cores_first_group; + uint32_t num_mcast_cores_last_group; + + // noc addrs for first and last groups + uint64_t reduce_sender_first_group_semaphore_noc_addr; + uint64_t multicast_first_group_data_noc; + uint64_t reduce_sender_last_group_semaphore_noc_addr; + uint64_t multicast_last_group_data_noc; + + if (has_mcast_first_group and has_mcast_last_group) { + + mcast_first_group_dest_noc_start_x = get_arg_val(7); + mcast_first_group_dest_noc_start_y = get_arg_val(8); + mcast_first_group_dest_noc_end_x = get_arg_val(9); + mcast_first_group_dest_noc_end_y = get_arg_val(10); + num_mcast_cores_first_group = get_arg_val(11); + + mcast_last_group_dest_noc_start_x = get_arg_val(12); + mcast_last_group_dest_noc_start_y = get_arg_val(13); + mcast_last_group_dest_noc_end_x = get_arg_val(14); + mcast_last_group_dest_noc_end_y = get_arg_val(15); + num_mcast_cores_last_group = get_arg_val(16); + + noc_coord_x = (volatile tt_l1_ptr uint32_t*)(get_arg_addr(17)); + noc_coord_y = (volatile tt_l1_ptr uint32_t*)(get_arg_addr(17+num_mcast_cores)); + + } else if (has_mcast_first_group and not has_mcast_last_group) { + mcast_first_group_dest_noc_start_x = get_arg_val(7); + mcast_first_group_dest_noc_start_y = get_arg_val(8); + mcast_first_group_dest_noc_end_x = get_arg_val(9); + mcast_first_group_dest_noc_end_y = get_arg_val(10); + num_mcast_cores_first_group = get_arg_val(11); + + noc_coord_x = (volatile tt_l1_ptr uint32_t*)(get_arg_addr(12)); + noc_coord_y = (volatile tt_l1_ptr uint32_t*)(get_arg_addr(12+num_mcast_cores)); + + } else if (not has_mcast_first_group and has_mcast_last_group) { + mcast_last_group_dest_noc_start_x = get_arg_val(7); + mcast_last_group_dest_noc_start_y = get_arg_val(8); + mcast_last_group_dest_noc_end_x = get_arg_val(9); + mcast_last_group_dest_noc_end_y = get_arg_val(10); + num_mcast_cores_last_group = get_arg_val(11); + + noc_coord_x = (volatile tt_l1_ptr uint32_t*)(get_arg_addr(12)); + noc_coord_y = (volatile tt_l1_ptr uint32_t*)(get_arg_addr(12+num_mcast_cores)); + + } else 
{ + noc_coord_x = (volatile tt_l1_ptr uint32_t*)(get_arg_addr(7)); + noc_coord_y = (volatile tt_l1_ptr uint32_t*)(get_arg_addr(7+num_mcast_cores)); + } + + const uint64_t reduce_sender_semaphore_noc_addr = get_noc_multicast_addr( + mcast_dest_noc_start_x, + mcast_dest_noc_start_y, + mcast_dest_noc_end_x, + mcast_dest_noc_end_y, + reduce_sender_semaphore_addr); + + const uint64_t multicast_data_noc = get_noc_multicast_addr( + mcast_dest_noc_start_x, + mcast_dest_noc_start_y, + mcast_dest_noc_end_x, + mcast_dest_noc_end_y, + 0); + + if (has_mcast_first_group) { + reduce_sender_first_group_semaphore_noc_addr = get_noc_multicast_addr( + mcast_first_group_dest_noc_start_x, + mcast_first_group_dest_noc_start_y, + mcast_first_group_dest_noc_end_x, + mcast_first_group_dest_noc_end_y, + reduce_sender_semaphore_addr); + + multicast_first_group_data_noc = get_noc_multicast_addr( + mcast_first_group_dest_noc_start_x, + mcast_first_group_dest_noc_start_y, + mcast_first_group_dest_noc_end_x, + mcast_first_group_dest_noc_end_y, + 0); + } + if (has_mcast_last_group) { + reduce_sender_last_group_semaphore_noc_addr = get_noc_multicast_addr( + mcast_last_group_dest_noc_start_x, + mcast_last_group_dest_noc_start_y, + mcast_last_group_dest_noc_end_x, + mcast_last_group_dest_noc_end_y, + reduce_sender_semaphore_addr); + + multicast_last_group_data_noc = get_noc_multicast_addr( + mcast_last_group_dest_noc_start_x, + mcast_last_group_dest_noc_start_y, + mcast_last_group_dest_noc_end_x, + mcast_last_group_dest_noc_end_y, + 0); + } + + volatile tt_l1_ptr uint32_t* reduce_sender_semaphore_addr_ptr = reinterpret_cast(reduce_sender_semaphore_addr); + *reduce_sender_semaphore_addr_ptr = VALID; + volatile tt_l1_ptr uint32_t* reduce_receiver_semaphore_addr_ptr = reinterpret_cast(reduce_receiver_semaphore_addr); + + constexpr uint32_t cb_ex_partial = tt::CB::dataflow0; + constexpr uint32_t cb_ex = tt::CB::dataflow1; + constexpr uint32_t cb_ex_external = tt::CB::dataflow2; + + const uint32_t single_tile_size_bytes = get_tile_size(cb_ex_partial); + const DataFormat data_format = get_dataformat(cb_ex_partial); + + if constexpr(num_mcast_cores > 1) { + for (uint32_t i=0; i < 2; ++i) { + // wait for local data ready + cb_wait_front(cb_ex_partial, num_group_batch); + + // wait for all other cores data ready + noc_semaphore_wait(reduce_receiver_semaphore_addr_ptr, num_mcast_cores-1); + noc_semaphore_set(reduce_receiver_semaphore_addr_ptr, 0); + + // read data from other cores + uint32_t l1_read_addr_ex_par = get_read_ptr(cb_ex_partial); + for (uint32_t bg = 0; bg < num_group_batch; ++bg) { + for(uint32_t i = 0; i < num_mcast_cores; ++i) { + uint32_t l1_write_addr_external = get_write_ptr(cb_ex_external); + cb_reserve_back(cb_ex_external, 1); + uint64_t noc_addr_ex_par = get_noc_addr(noc_coord_x[i], noc_coord_y[i], l1_read_addr_ex_par); + noc_async_read_one_packet(noc_addr_ex_par, l1_write_addr_external, single_tile_size_bytes); + noc_async_read_barrier(); + cb_push_back(cb_ex_external, 1); + } + l1_read_addr_ex_par += single_tile_size_bytes; + } + + // wait for global reduce done + cb_wait_front(cb_ex, num_group_batch); + cb_pop_front(cb_ex_partial, num_group_batch); + + // mcast to other cores + uint32_t l1_read_addr_ex = get_read_ptr(cb_ex); + noc_async_write_multicast(l1_read_addr_ex, multicast_data_noc | l1_read_addr_ex, num_group_batch * single_tile_size_bytes, num_mcast_cores_mid_group, true); + noc_semaphore_set_multicast(reduce_sender_semaphore_addr, reduce_sender_semaphore_noc_addr, num_mcast_cores_mid_group, false); + 
+ if (has_mcast_first_group) { + noc_async_write_multicast(l1_read_addr_ex, multicast_first_group_data_noc | l1_read_addr_ex, num_group_batch * single_tile_size_bytes, num_mcast_cores_first_group, true); + noc_semaphore_set_multicast(reduce_sender_semaphore_addr, reduce_sender_first_group_semaphore_noc_addr, num_mcast_cores_first_group, false); + } + + if (has_mcast_last_group) { + noc_async_write_multicast(l1_read_addr_ex, multicast_last_group_data_noc | l1_read_addr_ex, num_group_batch * single_tile_size_bytes, num_mcast_cores_last_group, true); + noc_semaphore_set_multicast(reduce_sender_semaphore_addr, reduce_sender_last_group_semaphore_noc_addr, num_mcast_cores_last_group, false); + } + noc_async_write_barrier(); + cb_pop_front(cb_ex, num_group_batch); + + } + } +} diff --git a/tt_eager/tt_dnn/op_library/groupnorm/kernels/dataflow/writer_unary_sharded_gn.cpp b/tt_eager/tt_dnn/op_library/groupnorm/kernels/dataflow/writer_unary_sharded_gn.cpp new file mode 100644 index 00000000000..d827a84b80e --- /dev/null +++ b/tt_eager/tt_dnn/op_library/groupnorm/kernels/dataflow/writer_unary_sharded_gn.cpp @@ -0,0 +1,109 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include +#include "dataflow_api.h" +#include "hostdevcommon/common_values.hpp" + +FORCE_INLINE void generate_bcast_scaler_w() { + constexpr uint32_t cb_in_2 = tt::CB::c_in2; + union { float f; uint32_t u; } u; u.u = get_arg_val(1); + cb_reserve_back(cb_in_2, 1); + auto ptr = reinterpret_cast(get_write_ptr(cb_in_2)); + + for (int k = 0; k < 4; k++) + for (int j = 0; j < 16; j++) + ptr[(k << 8) + j] = uint16_t(u.u>>16); + cb_push_back(cb_in_2, 1); +} + +FORCE_INLINE void generate_bcast_scaler_c() { + constexpr uint32_t cb_in_4 = tt::CB::c_in4; + union { float f; uint32_t u; } u; u.u = get_arg_val(0); + cb_reserve_back(cb_in_4, 1); + auto ptr = reinterpret_cast(get_write_ptr(cb_in_4)); + + for (int k = 0; k < 4; k++) + for (int j = 0; j < 16; j++) + ptr[(k << 8) + j] = uint16_t(u.u>>16); + cb_push_back(cb_in_4, 1); +} + +FORCE_INLINE void generate_epsilon() { + constexpr uint32_t eps_cb_id = tt::CB::c_in3; + union { float f; uint32_t u; } u; u.u = get_arg_val(2); + cb_reserve_back(eps_cb_id, 1); + auto ptr = reinterpret_cast(get_write_ptr(eps_cb_id)); + + for (int k = 0; k < 4; k+=2) + for (int j = 0; j < 16; j++) + ptr[(k << 8) + (j << 4)] = uint16_t(u.u>>16); + cb_push_back(eps_cb_id, 1); +} + +void kernel_main() { + constexpr bool is_mcast_sender = get_compile_time_arg_val(0) == 1; + constexpr bool fuse_gamma = get_compile_time_arg_val(1) == 1; + constexpr bool fuse_beta = get_compile_time_arg_val(2) == 1; + constexpr bool gamma_is_dram = get_compile_time_arg_val(3) == 1; + constexpr bool beta_is_dram = get_compile_time_arg_val(4) == 1; + constexpr uint32_t block_w = get_compile_time_arg_val(5); + + const uint32_t gamma_addr = get_arg_val(3); + const uint32_t beta_addr = get_arg_val(4); + const uint32_t gamma_tile_start_id = get_arg_val(5); + const uint32_t beta_tile_start_id = get_arg_val(6); + + constexpr uint32_t cb_gamma = tt::CB::c_in5; + constexpr uint32_t cb_beta = tt::CB::c_in6; + + // constexpr uint32_t block_w = 4; + const uint32_t single_tile_size_bytes = get_tile_size(cb_gamma); + + generate_bcast_scaler_w(); + if constexpr(is_mcast_sender) { + generate_bcast_scaler_c(); + } + generate_epsilon(); + + if constexpr(fuse_gamma) { + const uint32_t gamma_tile_bytes = get_tile_size(cb_gamma); + const DataFormat gamma_data_format = get_dataformat(cb_gamma); + const 
InterleavedAddrGenFast gamma = { + .bank_base_address = gamma_addr, + .page_size = gamma_tile_bytes, + .data_format = gamma_data_format + }; + + uint32_t l1_write_addr_gamma = get_write_ptr(cb_gamma); + cb_reserve_back(cb_gamma, block_w); + for (uint32_t w = 0; w < block_w; w++) { + uint32_t tile_id = gamma_tile_start_id + w; + noc_async_read_tile(tile_id, gamma, l1_write_addr_gamma); + l1_write_addr_gamma += gamma_tile_bytes; + } + noc_async_read_barrier(); + cb_push_back(cb_gamma, block_w); + } + + if constexpr(fuse_beta) { + const uint32_t beta_tile_bytes = get_tile_size(cb_beta); + const DataFormat beta_data_format = get_dataformat(cb_beta); + const InterleavedAddrGenFast beta = { + .bank_base_address = beta_addr, + .page_size = beta_tile_bytes, + .data_format = beta_data_format + }; + + uint32_t l1_write_addr_beta = get_write_ptr(cb_beta); + cb_reserve_back(cb_beta, block_w); + for (uint32_t w = 0; w < block_w; w++) { + uint32_t tile_id = beta_tile_start_id + w; + noc_async_read_tile(tile_id, beta, l1_write_addr_beta); + l1_write_addr_beta += beta_tile_bytes; + } + noc_async_read_barrier(); + cb_push_back(cb_beta, block_w); + } +} diff --git a/tt_eager/tt_dnn/op_library/groupnorm/kernels/dataflow/writer_unary_sharded_gn_rm_gb.cpp b/tt_eager/tt_dnn/op_library/groupnorm/kernels/dataflow/writer_unary_sharded_gn_rm_gb.cpp new file mode 100644 index 00000000000..a7b3f0443f5 --- /dev/null +++ b/tt_eager/tt_dnn/op_library/groupnorm/kernels/dataflow/writer_unary_sharded_gn_rm_gb.cpp @@ -0,0 +1,133 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include +#include "dataflow_api.h" +#include "hostdevcommon/common_values.hpp" + + +FORCE_INLINE void generate_bcast_scaler_c() { + constexpr uint32_t cb_in_4 = tt::CB::c_in4; + union { float f; uint32_t u; } u; u.u = get_arg_val(0); + cb_reserve_back(cb_in_4, 1); + auto ptr = reinterpret_cast(get_write_ptr(cb_in_4)); + + for (int k = 0; k < 4; k++) + for (int j = 0; j < 16; j++) + ptr[(k << 8) + j] = uint16_t(u.u>>16); + cb_push_back(cb_in_4, 1); +} + +FORCE_INLINE void generate_bcast_scaler_w() { + constexpr uint32_t cb_in_2 = tt::CB::c_in2; + union { float f; uint32_t u; } u; u.u = get_arg_val(1); + cb_reserve_back(cb_in_2, 1); + auto ptr = reinterpret_cast(get_write_ptr(cb_in_2)); + + for (int k = 0; k < 4; k++) + for (int j = 0; j < 16; j++) + ptr[(k << 8) + j] = uint16_t(u.u>>16); + cb_push_back(cb_in_2, 1); +} + +FORCE_INLINE void generate_epsilon() { + constexpr uint32_t eps_cb_id = tt::CB::c_in3; + union { float f; uint32_t u; } u; u.u = get_arg_val(2); + cb_reserve_back(eps_cb_id, 1); + auto ptr = reinterpret_cast(get_write_ptr(eps_cb_id)); + + for (int k = 0; k < 4; k+=2) + for (int j = 0; j < 16; j++) + ptr[(k << 8) + (j << 4)] = uint16_t(u.u>>16); + cb_push_back(eps_cb_id, 1); +} + +void kernel_main() { + constexpr bool is_mcast_sender = get_compile_time_arg_val(0) == 1; + constexpr bool fuse_gamma = get_compile_time_arg_val(1) == 1; + constexpr bool fuse_beta = get_compile_time_arg_val(2) == 1; + constexpr bool gamma_is_dram = get_compile_time_arg_val(3) == 1; + constexpr bool beta_is_dram = get_compile_time_arg_val(4) == 1; + constexpr uint32_t block_w = get_compile_time_arg_val(5); + + const uint32_t gamma_addr = get_arg_val(3); + const uint32_t beta_addr = get_arg_val(4); + const uint32_t gamma_tile_start_id = get_arg_val(5); + const uint32_t beta_tile_start_id = get_arg_val(6); + + constexpr uint32_t cb_gamma = tt::CB::c_in5; + constexpr uint32_t cb_beta = tt::CB::c_in6; + + // 
constexpr uint32_t block_w = 4; + const uint32_t single_tile_size_bytes = get_tile_size(cb_gamma); + + generate_bcast_scaler_w(); + if constexpr(is_mcast_sender) { + generate_bcast_scaler_c(); + } + generate_epsilon(); + + #define stick_size_is_pow2 get_compile_time_arg_val(6) == 1 + #if (stick_size_is_pow2) + constexpr uint32_t log_base_2_of_page_size = get_compile_time_arg_val(7); + #else + constexpr uint32_t page_size = get_compile_time_arg_val(7); + #endif + + if constexpr(fuse_gamma) { + const uint32_t gamma_tile_bytes = get_tile_size(cb_gamma); + #if (stick_size_is_pow2) + const InterleavedPow2AddrGen gamma = { + .bank_base_address = gamma_addr, + .log_base_2_of_page_size = log_base_2_of_page_size + }; + #else + const InterleavedAddrGen gamma = { + .bank_base_address = gamma_addr, + .page_size = page_size + }; + #endif + + uint32_t l1_write_addr_gamma = get_write_ptr(cb_gamma); + cb_reserve_back(cb_gamma, block_w); + for (uint32_t w = 0; w < block_w; w++) { + uint32_t tile_id = gamma_tile_start_id + w; + uint64_t gamma_noc_addr = get_noc_addr(tile_id, gamma); + noc_async_read(gamma_noc_addr, l1_write_addr_gamma, 32); + gamma_noc_addr += 32; + noc_async_read(gamma_noc_addr, l1_write_addr_gamma + 512, 32); + l1_write_addr_gamma += gamma_tile_bytes; + } + noc_async_read_barrier(); + cb_push_back(cb_gamma, block_w); + } + + if constexpr(fuse_beta) { + const uint32_t beta_tile_bytes = get_tile_size(cb_beta); + #if (stick_size_is_pow2) + const InterleavedPow2AddrGen beta = { + .bank_base_address = beta_addr, + .log_base_2_of_page_size = log_base_2_of_page_size + }; + #else + const InterleavedAddrGen beta = { + .bank_base_address = beta_addr, + .page_size = page_size + }; + #endif + + uint32_t l1_write_addr_beta = get_write_ptr(cb_beta); + cb_reserve_back(cb_beta, block_w); + for (uint32_t w = 0; w < block_w; w++) { + uint32_t tile_id = beta_tile_start_id + w; + uint64_t beta_noc_addr = get_noc_addr(tile_id, beta); + noc_async_read(beta_noc_addr, l1_write_addr_beta, 32); + beta_noc_addr += 32; + noc_async_read(beta_noc_addr, l1_write_addr_beta + 512, 32); + l1_write_addr_beta += beta_tile_bytes; + } + noc_async_read_barrier(); + cb_push_back(cb_beta, block_w); + } +} diff --git a/tt_eager/tt_lib/csrc/operations/primary/module.hpp b/tt_eager/tt_lib/csrc/operations/primary/module.hpp index 620af472e1c..8fa397f4c49 100644 --- a/tt_eager/tt_lib/csrc/operations/primary/module.hpp +++ b/tt_eager/tt_lib/csrc/operations/primary/module.hpp @@ -23,6 +23,7 @@ #include "tt_dnn/op_library/moreh_cumsum/moreh_cumsum_op.hpp" #include "tt_dnn/op_library/moreh_arange/moreh_arange_op.hpp" #include "tt_dnn/op_library/moreh_sgd/moreh_sgd_op.hpp" +#include "tt_dnn/op_library/groupnorm/groupnorm_op.hpp" namespace py = pybind11; @@ -583,6 +584,32 @@ void py_module(py::module& m_primary) { py::arg("nesterov").noconvert(), py::arg("momentum_initialized").noconvert(), "Performs a SGD operation."); + + py::class_(m_primary, "GroupNormShardedMultiCoreProgramConfig") + .def( + py::init(), + py::kw_only(), + py::arg("compute_with_storage_grid_size"), + py::arg("math_fidelity").noconvert() = MathFidelity::HiFi4, + py::arg("im_data_format").noconvert() = DataType::BFLOAT16, + py::arg("out_data_format").noconvert() = DataType::BFLOAT16, + py::arg("inplace").noconvert() = false + ); + + m_primary.def( + "groupnorm", + &groupnorm, + py::arg("input").noconvert(), + py::arg("num_groups").noconvert(), + py::arg("batch").noconvert(), + py::arg("eps").noconvert(), + py::arg("gamma").noconvert() = std::nullopt, + 
py::arg("beta").noconvert() = std::nullopt, + py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, + py::arg("program_config").noconvert() = GroupNormShardedMultiCoreProgramConfig{}, + R"doc( + Performs a groupnorm operation, returna a output tensor the same shape as input. + )doc"); } } // namespace