diff --git a/tests/ttnn/unit_tests/operations/test_conv.py b/tests/ttnn/unit_tests/operations/test_conv.py index 93f8c9bbfa7..25a91e44ab6 100644 --- a/tests/ttnn/unit_tests/operations/test_conv.py +++ b/tests/ttnn/unit_tests/operations/test_conv.py @@ -11,66 +11,7 @@ import ttnn -@skip_for_wormhole_b0() -@pytest.mark.parametrize( - "batch_size, output_channels, input_channels, input_height, input_width, filter_height, filter_width, stride_h, stride_w, pad_h, pad_w, use_1d_systolic_array", - ( - # unique convs in rn50 (complete list) - # first conv post folding and input_channels padding to tile width - (8, 64, 16, 115, 115, 4, 4, 1, 1, 0, 0, True), - # rn50 layer1 - (8, 64, 64, 56, 56, 3, 3, 1, 1, 1, 1, True), - # rn50 layer2 - (8, 128, 128, 56, 56, 3, 3, 2, 2, 1, 1, True), - (20, 128, 128, 56, 56, 3, 3, 2, 2, 1, 1, True), - (8, 128, 128, 28, 28, 3, 3, 1, 1, 1, 1, True), - # rn50 layer3 - (8, 256, 256, 28, 28, 3, 3, 2, 2, 1, 1, False), - (8, 256, 256, 14, 14, 3, 3, 1, 1, 1, 1, False), - # rn50 layer4 - (8, 512, 512, 14, 14, 3, 3, 2, 2, 1, 1, False), - (8, 512, 512, 7, 7, 3, 3, 1, 1, 1, 1, False), - # sd convs with HxW=32x32 - # (1, 320, 320, 32, 32, 3, 3, 1, 1, 1, 1, False), - # (1, 320, 320, 32, 32, 3, 3, 2, 2, 1, 1, False), - # (1, 640, 640, 16, 16, 3, 3, 1, 1, 1, 1, False), - # (1, 640, 640, 16, 16, 3, 3, 2, 2, 1, 1, False), - # (1, 640, 640, 16, 16, 3, 3, 2, 2, 1, 1, False), # bfloat16 activations doesnt fit - # (1, 1280, 1280, 8, 8, 3, 3, 1, 1, 1, 1, False), # slighlty low pcc with 0.99689. bfloat16 weights doesnt fit - # (1, 1280, 1280, 8, 8, 3, 3, 2, 2, 1, 1, False), #fails to parallelize with sharding - # (1, 1280, 1280, 4, 4, 3, 3, 1, 1, 1, 1, False), #fails to parallelize with sharding - # (1, 1280, 1280, 16, 16, 3, 3, 1, 1, 1, 1, False), # slightly low pcc with 0.99698. bfloat16 weights doesnt fit - # (1, 640, 640, 32, 32, 3, 3, 1, 1, 1, 1, False), # doesnt fit at all.. for all data types - # sd conv with HxW=512x512 - # (1, 320, 320, 512, 512, 3, 3, 1, 1, 1, 1, False), # doesnt fit at all.. for all data types - # sd conv with HxW=256x256 - # (1, 320, 320, 256, 256, 3, 3, 1, 1, 1, 1, False), # doesnt fit at all.. for all data types - # sd convs with HxW=64x64 - # (1, 320, 320, 64, 64, 3, 3, 1, 1, 1, 1, False), # bfloat16 weights or activations doesnt fit - (1, 320, 320, 64, 64, 3, 3, 2, 2, 1, 1, False), - # (1, 640, 640, 32, 32, 3, 3, 1, 1, 1, 1, False), # doesnt fit at all.. for all datatypes - # (1, 640, 640, 32, 32, 3, 3, 2, 2, 1, 1, False), # bfloat16 weights or activations doesnt fit - # (1, 640, 640, 32, 32, 3, 3, 2, 2, 1, 1, False), # bfloat16 activations doesnt fit - # (1, 1280, 1280, 16, 16, 3, 3, 1, 1, 1, 1, False), # slighlty low pcc with 0.99689. bfloat16 weights doesnt fit - # (1, 1280, 1280, 16, 16, 3, 3, 2, 2, 1, 1, False), #slightly low pcc 0.99697. bfloat16 doesnt fit. - # (1, 1280, 1280, 8, 8, 3, 3, 1, 1, 1, 1, False), # slighlty low pcc with 0.99689. bfloat16 weights doesnt fit - # (1, 1280, 1280, 32, 32, 3, 3, 1, 1, 1, 1, False), # not tested yet - # (1, 640, 640, 64, 64, 3, 3, 1, 1, 1, 1, False), # not tested yet - ), -) -@pytest.mark.parametrize( - "weights_dtype", - [ttnn.bfloat16, ttnn.bfloat8_b], - ids=["weights_BFLOAT16", "weights_BFLOAT8_B"], -) -@pytest.mark.parametrize( - "activations_dtype", - [ttnn.bfloat16, ttnn.bfloat8_b], - ids=["activations_BFLOAT16", "activations_BFLOAT8_B"], -) -@pytest.mark.parametrize("math_fidelity", [ttnn.MathFidelity.HiFi4, ttnn.MathFidelity.LoFi], ids=["HiFi4", "LoFi"]) -def test_conv( - use_program_cache, +def run_conv( device, math_fidelity, activations_dtype, @@ -87,28 +28,8 @@ def test_conv( pad_h, pad_w, use_1d_systolic_array, + config_override, ): - if input_channels == 16: - pytest.skip("These tests are hanging in interleaved_to_sharded after rebase. Issue: #4336") - - if math_fidelity != ttnn.MathFidelity.LoFi: - pytest.skip( - "By default, only run tests with LoFi math for pipelines. For local unit testing, enable the other variants by uncommenting the skip here!" - ) - - if ( - activations_dtype == ttnn.bfloat16 - and batch_size == 20 - and ( - output_channels == 64 - or ( - stride_h == 2 - and (output_channels == 256 or (output_channels == 128 and weights_dtype == ttnn.bfloat16)) - ) - ) - ): - pytest.skip("Skipping test because it won't fit in L1!") - torch.manual_seed(0) conv_input_shape = [batch_size, input_channels, input_height, input_width] conv_weight_shape = [output_channels, input_channels, filter_height, filter_width] @@ -150,6 +71,7 @@ def test_conv( bias=tt_bias_tensor, math_fidelity=math_fidelity, weights_dtype=weights_dtype, + conv_blocking_and_parallelization_config_override=config_override, ) assert "conv" in reader_patterns_cache and "halo" in reader_patterns_cache @@ -169,7 +91,207 @@ def test_conv( torch_output_tensor = torch.permute(torch_output_tensor, (0, 3, 1, 2)) if math_fidelity == ttnn.MathFidelity.LoFi and activations_dtype == ttnn.bfloat8_b: - pcc = 0.998 + pcc = 0.9969 else: pcc = 0.998 assert_with_pcc(torch_output_tensor, torch_out_golden_tensor, pcc=pcc) + + +@skip_for_wormhole_b0() +@pytest.mark.parametrize( + "output_channels, input_channels, input_height, input_width, filter_height, filter_width, stride_h, stride_w, pad_h, pad_w, use_1d_systolic_array", + ( + # unique convs in rn50 (complete list) + # first conv post folding and input_channels padding to tile width + (64, 16, 115, 115, 4, 4, 1, 1, 0, 0, True), + # rn50 layer1 + (64, 64, 56, 56, 3, 3, 1, 1, 1, 1, True), + # rn50 layer2 + (128, 128, 56, 56, 3, 3, 2, 2, 1, 1, True), + (128, 128, 56, 56, 3, 3, 2, 2, 1, 1, True), + (128, 128, 28, 28, 3, 3, 1, 1, 1, 1, True), + # rn50 layer3 + (256, 256, 28, 28, 3, 3, 2, 2, 1, 1, False), + (256, 256, 14, 14, 3, 3, 1, 1, 1, 1, False), + # rn50 layer4 + (512, 512, 14, 14, 3, 3, 2, 2, 1, 1, False), + (512, 512, 7, 7, 3, 3, 1, 1, 1, 1, False), + ), +) +@pytest.mark.parametrize( + "batch_size", + [8, 16, 20], + ids=["batch_size_8", "batch_size_16", "batch_size_20"], +) +@pytest.mark.parametrize( + "weights_dtype", + [ttnn.bfloat16, ttnn.bfloat8_b], + ids=["weights_BFLOAT16", "weights_BFLOAT8_B"], +) +@pytest.mark.parametrize( + "activations_dtype", + [ttnn.bfloat16, ttnn.bfloat8_b], + ids=["activations_BFLOAT16", "activations_BFLOAT8_B"], +) +@pytest.mark.parametrize("math_fidelity", [ttnn.MathFidelity.HiFi4, ttnn.MathFidelity.LoFi], ids=["HiFi4", "LoFi"]) +def test_resnet50_conv( + use_program_cache, + device, + math_fidelity, + activations_dtype, + weights_dtype, + batch_size, + output_channels, + input_channels, + input_height, + input_width, + filter_height, + filter_width, + stride_h, + stride_w, + pad_h, + pad_w, + use_1d_systolic_array, +): + if input_channels == 16: + pytest.skip("These tests are hanging in interleaved_to_sharded after rebase. Issue: #4336") + + if math_fidelity != ttnn.MathFidelity.LoFi: + pytest.skip( + "By default, only run tests with LoFi math for pipelines. For local unit testing, enable the other variants by uncommenting the skip here!" + ) + + if ( + activations_dtype == ttnn.bfloat16 + and batch_size == 20 + and ( + output_channels == 64 + or ( + stride_h == 2 + and (output_channels == 256 or (output_channels == 128 and weights_dtype == ttnn.bfloat16)) + ) + ) + ): + pytest.skip("Skipping test because it won't fit in L1!") + + if ( + input_channels >= 320 + and (not input_channels == 512) + and (activations_dtype == ttnn.bfloat16 or weights_dtype == ttnn.bfloat16) + ): + pytest.skip("Skipping tests with bfloat16 for sd convs") + + run_conv( + device, + math_fidelity, + activations_dtype, + weights_dtype, + batch_size, + output_channels, + input_channels, + input_height, + input_width, + filter_height, + filter_width, + stride_h, + stride_w, + pad_h, + pad_w, + use_1d_systolic_array, + config_override=None, + ) + + +@skip_for_wormhole_b0() +@pytest.mark.parametrize( + "batch_size, output_channels, input_channels, input_height, input_width, filter_height, filter_width, stride_h, stride_w, pad_h, pad_w, use_1d_systolic_array, config_override", + ( + # sd convs with HxW=32x32 + # (1, 320, 320, 32, 32, 3, 3, 1, 1, 1, 1, False, None), + # (1, 320, 320, 32, 32, 3, 3, 2, 2, 1, 1, False, None), + # (1, 640, 640, 16, 16, 3, 3, 1, 1, 1, 1, False, None), + # (1, 640, 640, 16, 16, 3, 3, 2, 2, 1, 1, False, None), + # (1, 640, 640, 16, 16, 3, 3, 2, 2, 1, 1, False, None), # bfloat16 activations doesnt fit + # (1, 1280, 1280, 8, 8, 3, 3, 1, 1, 1, 1, False, None), # slighlty low pcc with 0.99689. bfloat16 weights doesnt fit + # (1, 1280, 1280, 8, 8, 3, 3, 2, 2, 1, 1, False, None), #fails to parallelize with sharding + # (1, 1280, 1280, 4, 4, 3, 3, 1, 1, 1, 1, False, None), #fails to parallelize with sharding + # (1, 1280, 1280, 16, 16, 3, 3, 1, 1, 1, 1, False, None), # slightly low pcc with 0.99698. bfloat16 weights doesnt fit + # (1, 640, 640, 32, 32, 3, 3, 1, 1, 1, 1, False, None), # doesnt fit at all.. for all data types + # sd convs with HxW=64x64 with batch size = 1 + # (2, 32, 16, 64, 64, 3, 3, 1, 1, 1, 1, True, None), # not supported + (1, 320, 320, 64, 64, 3, 3, 1, 1, 1, 1, False, {"act_block_h": 32}), # bfloat16 doesnt fit + (1, 320, 320, 64, 64, 3, 3, 2, 2, 1, 1, False, None), + (1, 640, 640, 32, 32, 3, 3, 1, 1, 1, 1, False, {"act_block_h": 32}), # + (1, 640, 640, 32, 32, 3, 3, 2, 2, 1, 1, False, None), # bfloat16 doesnt fit + (1, 1280, 1280, 16, 16, 3, 3, 1, 1, 1, 1, False, None), # bfloat16 weights doesnt fit + (1, 1280, 1280, 16, 16, 3, 3, 2, 2, 1, 1, False, None), # bfloat16 doesnt fit. + (1, 1280, 1280, 8, 8, 3, 3, 1, 1, 1, 1, False, None), # bfloat16 weights doesnt fit + (1, 1280, 1280, 32, 32, 3, 3, 1, 1, 1, 1, False, None), + (1, 640, 640, 64, 64, 3, 3, 1, 1, 1, 1, False, {"act_block_h": 32}), + # sd convs with HxW=64x64 with batch size=2 + (2, 320, 320, 64, 64, 3, 3, 1, 1, 1, 1, False, {"act_block_h": 64}), + (2, 320, 320, 64, 64, 3, 3, 2, 2, 1, 1, False, None), # fits with bfloat8_b + (2, 640, 640, 32, 32, 3, 3, 1, 1, 1, 1, False, {"act_block_h": 64}), + (2, 640, 640, 32, 32, 3, 3, 2, 2, 1, 1, False, None), # bfloat16 doesnt fit + (2, 1280, 1280, 16, 16, 3, 3, 1, 1, 1, 1, False, None), # bfloat16 doesnt fit + (2, 1280, 1280, 16, 16, 3, 3, 2, 2, 1, 1, False, {"act_block_h": 32}), # bfloat16 doesnt fit + (2, 1280, 1280, 8, 8, 3, 3, 1, 1, 1, 1, False, {"act_block_h": 32}), + (2, 1280, 1280, 32, 32, 3, 3, 1, 1, 1, 1, False, {"act_block_h": 64}), # bfloat16 doesnt fit + (2, 640, 640, 64, 64, 3, 3, 1, 1, 1, 1, False, {"act_block_h": 64}), + ), +) +@pytest.mark.parametrize( + "weights_dtype", + [ttnn.bfloat8_b], + ids=["weights_BFLOAT8_B"], +) +@pytest.mark.parametrize( + "activations_dtype", + [ttnn.bfloat8_b], + ids=["activations_BFLOAT8_B"], +) +@pytest.mark.parametrize("math_fidelity", [ttnn.MathFidelity.HiFi4, ttnn.MathFidelity.LoFi], ids=["HiFi4", "LoFi"]) +def test_sd_conv( + use_program_cache, + device, + math_fidelity, + activations_dtype, + weights_dtype, + batch_size, + output_channels, + input_channels, + input_height, + input_width, + filter_height, + filter_width, + stride_h, + stride_w, + pad_h, + pad_w, + use_1d_systolic_array, + config_override, +): + if math_fidelity != ttnn.MathFidelity.LoFi: + pytest.skip( + "By default, only run tests with LoFi math for pipelines. For local unit testing, enable the other variants by uncommenting the skip here!" + ) + + run_conv( + device, + math_fidelity, + activations_dtype, + weights_dtype, + batch_size, + output_channels, + input_channels, + input_height, + input_width, + filter_height, + filter_width, + stride_h, + stride_w, + pad_h, + pad_w, + use_1d_systolic_array, + config_override, + ) diff --git a/tt_eager/tt_dnn/op_library/conv/kernels/conv_bmm_tilize_col_major_out_blocks.cpp b/tt_eager/tt_dnn/op_library/conv/kernels/conv_bmm_tilize_col_major_out_blocks.cpp index c0c85564ba2..5418c52e35d 100644 --- a/tt_eager/tt_dnn/op_library/conv/kernels/conv_bmm_tilize_col_major_out_blocks.cpp +++ b/tt_eager/tt_dnn/op_library/conv/kernels/conv_bmm_tilize_col_major_out_blocks.cpp @@ -173,22 +173,23 @@ void MAIN { #ifdef SFPU_OP_INIT_ACTIVATION SFPU_OP_INIT_ACTIVATION #endif + // in1 num blocks w is the outer loop. Output blocks are computed in col major order. + for(uint32_t in1_block_w_i = 0; in1_block_w_i < in1_num_blocks_w; ++in1_block_w_i) { - #ifdef PRE_TILIZE - unpack_reconfig_data_format_srca(in1_cb_id, in0_pretilize_cb_id); + for(uint32_t in0_block_h_i = 0; in0_block_h_i < in0_num_blocks_h; ++in0_block_h_i) { - col_major_to_row_major_init(); - tilize_in(in0_pretilize_cb_id, in0_subblock_h, in0_block_w, in0_num_subblocks, tilized_in0_cb_id); - row_major_to_col_major_init(); + #ifdef PRE_TILIZE + unpack_reconfig_data_format_srca(in1_cb_id, in0_pretilize_cb_id); - // TODO: unpack_reconfig_data_format_srca(in0_pretilize_cb_id, in1_cb_id) doesn't work if in0 is BFLOATB_B and in1 is BFLOAT16 - mm_block_init_short(); - unpack_reconfig_data_format_srca(in1_cb_id); - #endif + col_major_to_row_major_init(); + tilize_in(in0_pretilize_cb_id, in0_subblock_h, in0_block_w, in0_num_subblocks, tilized_in0_cb_id); + row_major_to_col_major_init(); + + // TODO: unpack_reconfig_data_format_srca(in0_pretilize_cb_id, in1_cb_id) doesn't work if in0 is BFLOATB_B and in1 is BFLOAT16 + mm_block_init_short(); + unpack_reconfig_data_format_srca(in1_cb_id); + #endif - // in1 num blocks w is the outer loop. Output blocks are computed in col major order. - for(uint32_t in1_block_w_i = 0; in1_block_w_i < in1_num_blocks_w; ++in1_block_w_i) { - for(uint32_t in0_block_h_i = 0; in0_block_h_i < in0_num_blocks_h; ++in0_block_h_i) { bool enable_reload = false; #ifdef PACK_RELU @@ -296,6 +297,7 @@ void MAIN { PACK( cb_interface[matmul_partials_cb].fifo_wr_ptr = partials_cb_write_ptr ); } } + cb_pop_front(mm_in0_cb_id, in0_block_num_tiles); cb_pop_front(in1_cb_id, in1_block_num_tiles); } // for in0_num_blocks_w @@ -380,6 +382,7 @@ void MAIN { } } #endif + } // for in0_num_blocks_h #ifdef FUSE_BIAS bias_block_offset += in1_block_w; diff --git a/tt_eager/tt_dnn/op_library/conv/kernels/reader_conv_activations_2d_mcast_padded_with_halo_3x3_weights_v2.cpp b/tt_eager/tt_dnn/op_library/conv/kernels/reader_conv_activations_2d_mcast_padded_with_halo_3x3_weights_v2.cpp index 139693e73d4..d3e6c23b015 100644 --- a/tt_eager/tt_dnn/op_library/conv/kernels/reader_conv_activations_2d_mcast_padded_with_halo_3x3_weights_v2.cpp +++ b/tt_eager/tt_dnn/op_library/conv/kernels/reader_conv_activations_2d_mcast_padded_with_halo_3x3_weights_v2.cpp @@ -27,6 +27,7 @@ void kernel_main() { uint32_t conv_act_size_h = get_arg_val(i); i+=1; uint32_t weight_size_h = get_arg_val(i); i+=1; uint32_t weight_size_w = get_arg_val(i); i+=1; + uint32_t act_num_blocks_h = get_arg_val(i); i+=1; // uint32_t act_block_h_datums = get_arg_val(i); i+=1; i+=1; // skip an arg uint32_t act_block_num_tiles = get_arg_val(i); i+=1; @@ -172,67 +173,69 @@ void kernel_main() { // Reset reader_idx to finish act_block_h_datums reader_idx = 0; - cb_reserve_back(cb_id_act_row_major_bfloat16, act_block_num_tiles); - uint32_t l1_write_addr_act = get_write_ptr(cb_id_act_row_major_bfloat16); - - constexpr uint32_t stride_h_bytes = (conv_act_size_w+2) * conv_act_c_read_bytes; - static_assert(act_block_h_datums % 2 == 0); // need to be even to read 2 in the body, due to packing of 2 indices in 1 uint32_t word - // #pragma GCC unroll 4 // didn't seem to help (neutral), manual unroll 2x perf drop - for (uint32_t bh = 0; bh < act_block_h_datums/2; bh++) { - uint32_t two_reader_indices = packed_reader_indices_ptr[reader_idx]; - read_channels(l1_write_addr_act, act_l1_read_addr, two_reader_indices & 0xffff, conv_act_c_read_bytes, coalesced_read_bytes, stride_h_bytes); - read_channels(l1_write_addr_act, act_l1_read_addr, two_reader_indices >> 16 , conv_act_c_read_bytes, coalesced_read_bytes, stride_h_bytes); - - reader_idx++; - } - // incrementing num issued in one shot is actually slower - // noc_async_read_inc_num_issued(num_issued_reads_per_block); // "false" on read - noc_async_read_barrier(); - cb_push_back(cb_id_act_row_major_bfloat16, act_block_num_tiles); - - // compute tilizes and pops cb_id_act and pushes to tilized_in0_cb_id - cb_wait_front(tilized_in0_cb_id, act_block_num_tiles); - - - // Round robin self-mcast and receive tilized act matrix in cb_id_act - // Compute should function like regular mm - for (uint32_t act_w_outer_i = 0; act_w_outer_i < act_w_num_outer; act_w_outer_i++) { - if (act_w_outer_i == act_mcast_sender_id) { - // MCAST SENDER: send entire tilized input to other cores in column - cb_reserve_back(cb_id_act, act_block_num_tiles); - - // wait until all act mcast destinations have atomically incremented the act semaphore_addr (i.e. its value should be act_mcast_num_dests), then reset - // the semaphore_addr value back to zero for the next block - noc_semaphore_wait(act_mcast_sender_semaphore_addr_ptr, act_mcast_num_dests); - noc_semaphore_set(act_mcast_sender_semaphore_addr_ptr, 0); - - // Now we have the block in the CB address, we can mcast to dests! - uint32_t tilized_act_start_address = get_read_ptr(tilized_in0_cb_id); - uint64_t act_multicast_data_addr = act_multicast_noc_addr | get_write_ptr(cb_id_act); - // num_dests will source, since we are copying to a different local CB as well - noc_async_write_multicast_loopback_src(tilized_act_start_address, act_multicast_data_addr, act_mcast_sender_size_bytes, act_mcast_num_cores + 1); - - // Note: no need for write barrier, since these two multicasts are done on the same noc id, same vc, same cmd_buf - // Also, this only works because we are setting VCs statically (using NOC_CMD_STATIC_VC). - - // We should also multicast VALID flag to destinations for receiver semaphore - noc_semaphore_set_multicast(act_mcast_receiver_semaphore_addr, act_mcast_receiver_semaphore_noc_addr, act_mcast_num_cores); - - noc_async_write_barrier(); - } else { - // MCAST RECEIVER: receive entire tilized input from sender core - cb_reserve_back(cb_id_act, act_block_num_tiles); - - // Set act semaphore value to INVALID - noc_semaphore_set(act_mcast_receiver_semaphore_addr_ptr, INVALID); - - // Atomic increment source core counter - uint64_t act_mcast_sender_semaphore_noc_addr = get_noc_addr(act_mcast_sender_noc_x, act_mcast_sender_noc_y[act_w_outer_i], act_mcast_sender_semaphore_addr); - noc_semaphore_inc(act_mcast_sender_semaphore_noc_addr, 1); - - // wait on act semaphore value to become VALID (set by mcast sender after it multicasts data) - noc_semaphore_wait(act_mcast_receiver_semaphore_addr_ptr, VALID); + for (uint32_t nbh = 0; nbh < act_num_blocks_h; nbh++) { + cb_reserve_back(cb_id_act_row_major_bfloat16, act_block_num_tiles); + uint32_t l1_write_addr_act = get_write_ptr(cb_id_act_row_major_bfloat16); + + constexpr uint32_t stride_h_bytes = (conv_act_size_w+2) * conv_act_c_read_bytes; + static_assert(act_block_h_datums % 2 == 0); // need to be even to read 2 in the body, due to packing of 2 indices in 1 uint32_t word + // #pragma GCC unroll 4 // didn't seem to help (neutral), manual unroll 2x perf drop + for (uint32_t bh = 0; bh < act_block_h_datums/2; bh++) { + uint32_t two_reader_indices = packed_reader_indices_ptr[reader_idx]; + read_channels(l1_write_addr_act, act_l1_read_addr, two_reader_indices & 0xffff, conv_act_c_read_bytes, coalesced_read_bytes, stride_h_bytes); + read_channels(l1_write_addr_act, act_l1_read_addr, two_reader_indices >> 16 , conv_act_c_read_bytes, coalesced_read_bytes, stride_h_bytes); + + reader_idx++; } - cb_push_back(cb_id_act, act_block_num_tiles); + // incrementing num issued in one shot is actually slower + // noc_async_read_inc_num_issued(num_issued_reads_per_block); // "false" on read + noc_async_read_barrier(); + cb_push_back(cb_id_act_row_major_bfloat16, act_block_num_tiles); + // compute tilizes and pops cb_id_act and pushes to tilized_in0_cb_id + cb_wait_front(tilized_in0_cb_id, act_block_num_tiles); + + // Round robin self-mcast and receive tilized act matrix in cb_id_act + // Compute should function like regular mm + for (uint32_t act_w_outer_i = 0; act_w_outer_i < act_w_num_outer; act_w_outer_i++) { + if (act_w_outer_i == act_mcast_sender_id) { + // MCAST SENDER: send entire tilized input to other cores in column + cb_reserve_back(cb_id_act, act_block_num_tiles); + + // wait until all act mcast destinations have atomically incremented the act semaphore_addr (i.e. its value should be act_mcast_num_dests), then reset + // the semaphore_addr value back to zero for the next block + noc_semaphore_wait(act_mcast_sender_semaphore_addr_ptr, act_mcast_num_dests); + noc_semaphore_set(act_mcast_sender_semaphore_addr_ptr, 0); + + // Now we have the block in the CB address, we can mcast to dests! + uint32_t tilized_act_start_address = get_read_ptr(tilized_in0_cb_id); + + uint64_t act_multicast_data_addr = act_multicast_noc_addr | get_write_ptr(cb_id_act); + // num_dests will source, since we are copying to a different local CB as well + noc_async_write_multicast_loopback_src(tilized_act_start_address, act_multicast_data_addr, act_mcast_sender_size_bytes, act_mcast_num_cores + 1); + + // Note: no need for write barrier, since these two multicasts are done on the same noc id, same vc, same cmd_buf + // Also, this only works because we are setting VCs statically (using NOC_CMD_STATIC_VC). + + // We should also multicast VALID flag to destinations for receiver semaphore + noc_semaphore_set_multicast(act_mcast_receiver_semaphore_addr, act_mcast_receiver_semaphore_noc_addr, act_mcast_num_cores); + + noc_async_write_barrier(); + } else { + // MCAST RECEIVER: receive entire tilized input from sender core + cb_reserve_back(cb_id_act, act_block_num_tiles); + + // Set act semaphore value to INVALID + noc_semaphore_set(act_mcast_receiver_semaphore_addr_ptr, INVALID); + + // Atomic increment source core counter + uint64_t act_mcast_sender_semaphore_noc_addr = get_noc_addr(act_mcast_sender_noc_x, act_mcast_sender_noc_y[act_w_outer_i], act_mcast_sender_semaphore_addr); + noc_semaphore_inc(act_mcast_sender_semaphore_noc_addr, 1); + + // wait on act semaphore value to become VALID (set by mcast sender after it multicasts data) + noc_semaphore_wait(act_mcast_receiver_semaphore_addr_ptr, VALID); + } + cb_push_back(cb_id_act, act_block_num_tiles); + } // act_w_num_outer + cb_pop_front(tilized_in0_cb_id, act_block_num_tiles); } } diff --git a/tt_eager/tt_dnn/op_library/conv/multi_core_optimized_conv_sharded/optimized_conv_op_sharded_v2.cpp b/tt_eager/tt_dnn/op_library/conv/multi_core_optimized_conv_sharded/optimized_conv_op_sharded_v2.cpp index 75368cc9319..e7d531150fe 100644 --- a/tt_eager/tt_dnn/op_library/conv/multi_core_optimized_conv_sharded/optimized_conv_op_sharded_v2.cpp +++ b/tt_eager/tt_dnn/op_library/conv/multi_core_optimized_conv_sharded/optimized_conv_op_sharded_v2.cpp @@ -551,10 +551,11 @@ operation::ProgramWithCallbacks multi_core_optimized_conv_sharded_v2_(const Tens if (fully_buffer_weights) { num_weight_cb_tiles *= window_outer; - } else if (per_core_weight_matrix_width_ntiles < 8) { + } else if (per_core_weight_matrix_width_ntiles < 5 && per_core_out_matrix_height_ntiles < 22) { num_weight_cb_tiles = num_weight_cb_tiles * 2; } - if (conv_act_size_c / conv_act_c_blocks < 256) { + + if (conv_act_size_c / conv_act_c_blocks < 160 && per_core_out_matrix_height_ntiles < 22) { num_act_cb_tiles = num_act_cb_tiles * 2; // double buffered } @@ -876,7 +877,7 @@ operation::ProgramWithCallbacks multi_core_optimized_conv_sharded_v2_(const Tens conv_act_size_h, weight_size_h, weight_size_w, - + num_blocks_act_h_per_core, act_block_h_datums, in0_block_num_tiles, conv_act_c_blocks,