#8407: Remove 1x1 matmul fallback on convolution and generalize convolution kernel
tapspatel committed Jun 3, 2024
1 parent 0a220d9 commit 54bc6fd
Showing 7 changed files with 70 additions and 139 deletions.
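The net effect of the change: only 1x1 convolutions with stride 1, zero padding, unit dilation, and a single group keep the matmul fallback; 1x1 stride-2 convolutions, such as the ResNet-50 bottleneck downsample shortcuts, now run through the generalized convolution kernel. A minimal sketch of the updated routing condition, mirroring the checks changed in conv2d.cpp and conv2d.py below (the helper name is illustrative, not part of the codebase):

    def uses_matmul_fallback(kernel_size, stride, padding, dilation=(1, 1), groups=1):
        # Only 1x1, stride-1, unpadded, undilated, ungrouped convs keep the matmul fallback.
        return (
            kernel_size == (1, 1)
            and stride[0] == stride[1] == 1
            and padding == (0, 0)
            and dilation == (1, 1)
            and groups == 1
        )

    # ResNet-50 bottleneck downsample (1x1 kernel, stride 2) now uses the conv kernel.
    assert not uses_matmul_fallback((1, 1), (2, 2), (0, 0))
    # A plain pointwise 1x1 stride-1 conv is still lowered to matmul.
    assert uses_matmul_fallback((1, 1), (1, 1), (0, 0))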
9 changes: 9 additions & 0 deletions tests/ttnn/unit_tests/operations/test_new_conv2d.py
@@ -417,6 +417,15 @@ def test_resnet50_conv_gs(
(1, 64, 64, 16, 16, 3, 3, 1, 1, 1, 1, False, {"num_cores_nhw": 4, "grid_size": (2, 4)}),
# (1, 160, 160, 7, 7, 3, 3, 1, 1, 1, 1, False, None), sliding_window_op_infra/sliding_window.cpp:341: indices_length_last_core <= indices_length_per_core
(8, 256, 256, 7, 7, 3, 3, 1, 1, 1, 1, False, None),
# r50 1x1s2 shapes
(20, 256, 64, 56, 56, 1, 1, 2, 2, 0, 0, False, None), # r50 first bottleneck downsample shape
(20, 256, 64, 56, 56, 1, 1, 2, 2, 0, 0, True, None), # r50 first bottleneck downsample shape
(20, 512, 256, 56, 56, 1, 1, 2, 2, 0, 0, False, None), # r50 second bottleneck downsample shape
# (20, 512, 256, 56, 56, 1, 1, 2, 2, 0, 0, True, None), - doesn't fit
(20, 1024, 512, 28, 28, 1, 1, 2, 2, 0, 0, False, None), # r50 third bottleneck downsample shape
# (20, 1024, 512, 28, 28, 1, 1, 2, 2, 0, 0, True, None), - doesn't fit
(20, 2048, 1024, 14, 14, 1, 1, 2, 2, 0, 0, False, None), # r50 fourth bottleneck downsample shape
# (20, 2048, 1024, 14, 14, 1, 1, 2, 2, 0, 0, True, None), - doesn't fit
),
)
@pytest.mark.parametrize(
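The new r50 1x1s2 cases above are the ResNet-50 bottleneck downsample convolutions; with a 1x1 kernel, stride 2, and no padding, each stride-2 layer simply halves the spatial size. A quick check of the arithmetic (the helper below is illustrative, not part of the test file):

    def conv_out_dim(in_dim, kernel, stride, pad):
        # Standard convolution output size (dilation 1).
        return (in_dim + 2 * pad - kernel) // stride + 1

    # e.g. (20, 256, 64, 56, 56, 1, 1, 2, 2, 0, 0, ...): 56x56 input, 1x1 kernel, stride 2, pad 0
    assert conv_out_dim(56, kernel=1, stride=2, pad=0) == 28
    assert conv_out_dim(28, kernel=1, stride=2, pad=0) == 14
    assert conv_out_dim(14, kernel=1, stride=2, pad=0) == 7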
@@ -41,24 +41,20 @@ void kernel_main() {

constexpr bool act_in_dram = get_compile_time_arg_val(0) == 1;
constexpr uint32_t stride_h = get_compile_time_arg_val(1);
constexpr uint32_t stride_w = get_compile_time_arg_val(2);
constexpr uint32_t conv_act_size_w = get_compile_time_arg_val(3);
constexpr uint32_t conv_output_w_last_index = get_compile_time_arg_val(4) - 1;
constexpr uint32_t conv_act_c_read_bytes = get_compile_time_arg_val(5);
// need to have these as compile-time since we unroll loops based on them
constexpr uint32_t window_outer = get_compile_time_arg_val(6);
constexpr uint32_t window_inner = get_compile_time_arg_val(7);
constexpr uint32_t act_block_h_datums = get_compile_time_arg_val(8);

constexpr uint32_t weight_size_w = get_compile_time_arg_val(10);
constexpr uint32_t act_num_blocks_h = get_compile_time_arg_val(14);
constexpr uint32_t act_block_num_tiles = get_compile_time_arg_val(15);
constexpr uint32_t act_w_num_outer = get_compile_time_arg_val(16);

constexpr uint32_t act_mcast_num_dests = get_compile_time_arg_val(17);
constexpr uint32_t act_mcast_num_cores = get_compile_time_arg_val(18);
constexpr uint32_t act_mcast_sender_semaphore_addr = get_compile_time_arg_val(19);
constexpr uint32_t act_mcast_receiver_semaphore_addr = get_compile_time_arg_val(20);
constexpr uint32_t act_mcast_sender_size_bytes = get_compile_time_arg_val(21);
constexpr uint32_t pad_w = get_compile_time_arg_val(22);

constexpr bool transpose_mcast = get_compile_time_arg_val(22) == 1;

@@ -114,8 +110,7 @@ void kernel_main() {

// TODO: need to make the read coalescing optimization cleaner
// currently works for the case of num_coalesced_reads == weight_size_w since these reads are contiguous on both src/dst side
constexpr uint32_t num_coalesced_reads = 3;
constexpr uint32_t coalesced_read_bytes = num_coalesced_reads * conv_act_c_read_bytes;
constexpr uint32_t coalesced_read_bytes = weight_size_w * conv_act_c_read_bytes;


// Fully create act matrix and tilize it before mcast
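With the hardcoded num_coalesced_reads = 3 removed, the coalesced read size now scales with the filter width, so the same reader covers both 3x3 and 1x1 windows. Roughly, the sizing works out as below (the byte counts are assumptions for illustration, e.g. a 32-channel bfloat16 activation stick):

    def coalesced_read_bytes(weight_size_w, conv_act_c_read_bytes):
        # One NOC read now covers weight_size_w horizontally adjacent activation sticks.
        return weight_size_w * conv_act_c_read_bytes

    stick_bytes = 32 * 2  # assumed: 32 channels * 2 bytes (bfloat16) per activation stick
    print(coalesced_read_bytes(3, stick_bytes))  # 3-wide filter row -> 192 bytes per read
    print(coalesced_read_bytes(1, stick_bytes))  # 1x1 filter -> a single 64-byte stick per read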
@@ -84,138 +84,61 @@ void kernel_main() {
// the conditional selecting between coalescing and no-coalescing must be constexpr so that the compiler can optimize the other path away
// this has shown to be a big perf win
static_assert(act_block_h_datums % 2 == 0); // need to be even to read 2 in the body, due to packing of 2 indices in 1 uint32_t word
if constexpr (coalesce_window_inner_reads and window_inner == num_coalesced_reads) {
// coalesce reads along weight_size_w
reader_offset_idx = 0;
uint32_t act_l1_offset = 0;
uint32_t act_l1_read_addr = get_read_ptr(cb_id_sharded_act);

static_assert(coalesced_read_bytes <= NOC_MAX_BURST_SIZE);
// set_state uses just x/y from the get_noc_addr, addr is ignored
noc_async_read_one_packet_set_state(get_noc_addr(act_l1_read_addr), coalesced_read_bytes);
uint32_t start_reader_idx = 0;
for (uint32_t bh = 0; bh < act_num_blocks_h; bh++) {
#ifdef SPLIT_READER
if constexpr (cache_packed_reader_indices) {
for (uint32_t i = 0; i < act_block_h_datums_read; i++) {
local_packed_reader_indices[i] = packed_reader_indices_ptr[start_reader_idx+i];
}
}
#endif
for (uint32_t outer = 0; outer < window_outer; outer++) {
// Reset reader_idx to finish act_block_h_datums
reader_idx = start_reader_idx;

cb_reserve_back(cb_id_act, act_block_num_tiles_read);
uint32_t l1_write_addr_act = get_write_ptr(cb_id_act);
uint32_t reader_offset = act_l1_read_addr + (reader_offsets[reader_offset_idx] * conv_act_c_read_bytes);
// #pragma GCC unroll 4 // unroll didn't help, but act_block_h_datums (loop bound) being const does help
for (uint32_t bhd = 0; bhd < act_block_h_datums_read; bhd++) {
// local read from reader_index + reader_offset;
#ifdef SPLIT_READER
uint32_t two_reader_indices = cache_packed_reader_indices ? local_packed_reader_indices[bhd] : packed_reader_indices_ptr[reader_idx];
#else // no split reader
uint32_t two_reader_indices = packed_reader_indices_ptr[reader_idx];
#endif
uint32_t reader_idx_1 = two_reader_indices & 0xffff;
uint32_t reader_idx_2 = two_reader_indices >> 16;

act_l1_offset = reader_offset + (reader_idx_1 * conv_act_c_read_bytes);
noc_async_read_one_packet_with_state<true>(act_l1_offset, l1_write_addr_act);
l1_write_addr_act += (coalesced_read_bytes + act_block_w_extra_align_bytes);

act_l1_offset = reader_offset + (reader_idx_2 * conv_act_c_read_bytes);
noc_async_read_one_packet_with_state<true>(act_l1_offset, l1_write_addr_act);
l1_write_addr_act += (coalesced_read_bytes + act_block_w_extra_align_bytes);

reader_idx++;
}
noc_async_read_barrier();
cb_push_back(cb_id_act, act_block_num_tiles_read);

reader_offset_idx += window_inner;
// coalesce reads along weight_size_w
reader_offset_idx = 0;
uint32_t act_l1_offset = 0;
uint32_t act_l1_read_addr = get_read_ptr(cb_id_sharded_act);

static_assert(coalesced_read_bytes <= NOC_MAX_BURST_SIZE);
// set_state uses just x/y from the get_noc_addr, addr is ignored
noc_async_read_one_packet_set_state(get_noc_addr(act_l1_read_addr), coalesced_read_bytes);
uint32_t start_reader_idx = 0;
for (uint32_t bh = 0; bh < act_num_blocks_h; bh++) {
#ifdef SPLIT_READER
if constexpr (cache_packed_reader_indices) {
for (uint32_t i = 0; i < act_block_h_datums_read; i++) {
local_packed_reader_indices[i] = packed_reader_indices_ptr[start_reader_idx+i];
}
reader_offset_idx = 0;

start_reader_idx = reader_idx;
#ifdef SPLIT_READER
start_reader_idx += act_block_h_datums_read;
#endif
}

} else {
// NOTE: This code block expects reader_indices_ptr to be uint32_t (not packed uint16_t)
// Inner window dim is usually 3, so reading packed indices is complicated
// TODO: We could probably just remove this block if no convs use it

// no coalescing of reads
reader_offset_idx = 0;
uint32_t act_l1_offset = 0;
uint32_t act_l1_read_addr = get_read_ptr(cb_id_sharded_act);

static_assert(conv_act_c_read_bytes <= NOC_MAX_BURST_SIZE);
// set_state uses just x/y from the get_noc_addr, addr is ignored
noc_async_read_one_packet_set_state(get_noc_addr(act_l1_read_addr), conv_act_c_read_bytes);

uint32_t start_reader_idx = 0;
for (uint32_t bh = 0; bh < act_num_blocks_h; bh++) {
#endif
for (uint32_t outer = 0; outer < window_outer; outer++) {
// Reset reader_idx to finish act_block_h_datums
reader_idx = start_reader_idx;
cb_reserve_back(cb_id_act, act_block_num_tiles);

cb_reserve_back(cb_id_act, act_block_num_tiles_read);
uint32_t l1_write_addr_act = get_write_ptr(cb_id_act);
for (uint32_t bhd = 0; bhd < act_block_h_datums; bhd++) {
// when no read coalescing, main use case is window_inner == 1,
// and if window_inner is const this loop should be removed by the compiler
uint32_t reader_offset = act_l1_read_addr + (reader_offsets[reader_offset_idx] * conv_act_c_read_bytes);
// #pragma GCC unroll 4 // unroll didn't help, but act_block_h_datums (loop bound) being const does help
for (uint32_t bhd = 0; bhd < act_block_h_datums_read; bhd++) {
// local read from reader_index + reader_offset;
#ifdef SPLIT_READER
uint32_t packed_reader_idx = packed_reader_indices_ptr[reader_idx];
if constexpr (cache_packed_reader_indices) {
local_packed_reader_indices[bhd] = packed_reader_idx;
}
#else
uint32_t packed_reader_idx = packed_reader_indices_ptr[reader_idx];
uint32_t two_reader_indices = cache_packed_reader_indices ? local_packed_reader_indices[bhd] : packed_reader_indices_ptr[reader_idx];
#else // no split reader
uint32_t two_reader_indices = packed_reader_indices_ptr[reader_idx];
#endif
for (uint32_t inner = 0; inner < window_inner; inner++) {
// local read from reader_index + reader_offset;
act_l1_offset = act_l1_read_addr + ((packed_reader_idx + reader_offsets[reader_offset_idx + inner]) * conv_act_c_read_bytes);
noc_async_read_one_packet_with_state<true>(act_l1_offset, l1_write_addr_act);
l1_write_addr_act += conv_act_c_read_bytes;
uint32_t reader_idx_1 = two_reader_indices & 0xffff;
uint32_t reader_idx_2 = two_reader_indices >> 16;

act_l1_offset = reader_offset + (reader_idx_1 * conv_act_c_read_bytes);
noc_async_read_one_packet_with_state<true>(act_l1_offset, l1_write_addr_act);
l1_write_addr_act += (coalesced_read_bytes + act_block_w_extra_align_bytes);

act_l1_offset = reader_offset + (reader_idx_2 * conv_act_c_read_bytes);
noc_async_read_one_packet_with_state<true>(act_l1_offset, l1_write_addr_act);
l1_write_addr_act += (coalesced_read_bytes + act_block_w_extra_align_bytes);

}
reader_idx++;
}
noc_async_read_barrier();
cb_push_back(cb_id_act, act_block_num_tiles);

reader_offset_idx += 3*window_inner;
for (uint32_t outer = 1; outer < window_outer; outer++) {
// Reset reader_idx to finish act_block_h_datums
reader_idx = start_reader_idx;
cb_reserve_back(cb_id_act, act_block_num_tiles);
uint32_t l1_write_addr_act = get_write_ptr(cb_id_act);
for (uint32_t bhd = 0; bhd < act_block_h_datums; bhd++) {
// when no read coalescing, main use case is window_inner == 1,
// and if window_inner is const this loop should be removed by the compiler
#ifdef SPLIT_READER
uint32_t packed_reader_idx = cache_packed_reader_indices ? local_packed_reader_indices[bhd] : packed_reader_indices_ptr[reader_idx];
#else
uint32_t packed_reader_idx = packed_reader_indices_ptr[reader_idx];
#endif
for (uint32_t inner = 0; inner < window_inner; inner++) {
// local read from reader_index + reader_offset;
act_l1_offset = act_l1_read_addr + ((packed_reader_idx + reader_offsets[reader_offset_idx + inner]) * conv_act_c_read_bytes);
noc_async_read_one_packet_with_state<true>(act_l1_offset, l1_write_addr_act);
l1_write_addr_act += conv_act_c_read_bytes;

}
reader_idx++;
}
noc_async_read_barrier();
cb_push_back(cb_id_act, act_block_num_tiles);

reader_offset_idx += 3*window_inner;
}
reader_offset_idx = 0;
start_reader_idx = reader_idx;
cb_push_back(cb_id_act, act_block_num_tiles_read);

reader_offset_idx += window_inner;
}
reader_offset_idx = 0;

start_reader_idx = reader_idx;
#ifdef SPLIT_READER
start_reader_idx += act_block_h_datums_read;
#endif
}
}
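The reader consumes packed indices: each uint32_t in packed_reader_indices_ptr carries two 16-bit row indices, which is why act_block_h_datums must be even. A small sketch of the packing scheme, matching how the kernel unpacks it (helper names and sample values are illustrative):

    def pack_indices(idx_1, idx_2):
        # Two 16-bit reader indices in one 32-bit word, idx_1 in the low half.
        assert idx_1 < (1 << 16) and idx_2 < (1 << 16)
        return (idx_2 << 16) | idx_1

    def unpack_indices(word):
        # Mirrors the kernel: reader_idx_1 = word & 0xffff; reader_idx_2 = word >> 16.
        return word & 0xFFFF, word >> 16

    assert unpack_indices(pack_indices(5, 260)) == (5, 260)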
@@ -565,13 +565,15 @@ operation::ProgramWithCallbacks multi_core_optimized_conv_sharded_v2_impl(

uint32_t window_outer;
uint32_t window_inner;
if (weight_width_sliced) {

if (weight_width_sliced and weight_size_w == 3) {
window_outer = 1; // window_outer = 1 because all of the filter window is processed in the inner loop
window_inner = 3; // window_inner = 9 / 3, i.e. read 3 width coalesced
} else {
window_outer = num_blocks_act_w; // window_outer
window_inner = weight_size_h * weight_size_w / num_blocks_act_w; // window_inner
}

reader_defines["WINDOW_INNER"] = std::to_string(window_inner);
log_debug(LogOp, "window_outer: {}, window_inner: {}", window_outer, window_inner);
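The window split now keys off the actual filter width instead of assuming 3x3: weight-width-sliced 3x3 convs still read the whole filter row in the inner loop, while other shapes, including the new 1x1 stride-2 case, take the general split. A sketch of the selection logic above (the function and sample arguments are illustrative):

    def window_split(weight_width_sliced, weight_size_h, weight_size_w, num_blocks_act_w):
        # Mirrors the host-side selection: width-sliced 3x3 convs coalesce the filter row.
        if weight_width_sliced and weight_size_w == 3:
            return 1, 3  # window_outer, window_inner
        return num_blocks_act_w, (weight_size_h * weight_size_w) // num_blocks_act_w

    assert window_split(True, 3, 3, num_blocks_act_w=1) == (1, 3)
    # 1x1 stride-2 downsample: the (trivial) window is handled per outer block.
    assert window_split(True, 1, 1, num_blocks_act_w=1) == (1, 1)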

@@ -709,17 +711,17 @@ operation::ProgramWithCallbacks multi_core_optimized_conv_sharded_v2_impl(
}
}

bool read_3x3_window_in_inner_loop = false;
bool read_window_in_inner_loop = false;
uint32_t num_weight_cb_tiles = weight_block_h_ntiles * weight_block_w_ntiles / conv_act_c_blocks;
bool fully_buffer_weights = false;
uint32_t num_act_cb_tiles = act_block_h_ntiles * act_block_w_ntiles / conv_act_c_blocks;
// TODO: This flag should be set in kernel logic but need this for create_CB
if (a.memory_config().is_sharded() and weight_size_h == 3 and weight_size_w == 3 and
(stride_h == 1 or stride_h == 2) and weight_width_sliced) {
if (a.memory_config().is_sharded() and ((weight_size_h == 3 and weight_size_w == 3 and
(stride_h == 1 or stride_h == 2)) or (weight_size_h == 1 and weight_size_w == 1 and stride_h == 2)) and weight_width_sliced) {
// If conv_act_c_blocks > 1 and we have a 2D conv with sharded input, we always read the entire filter window before
// pushing in the reader/writer
// TODO: Generalize this to not make this assumption
read_3x3_window_in_inner_loop = true;
read_window_in_inner_loop = true;
num_weight_cb_tiles *= weight_size_h * weight_size_w;
num_act_cb_tiles *= weight_size_h * weight_size_w;
} else if (num_blocks_act_h_per_core > 1) {
@@ -800,10 +802,10 @@ operation::ProgramWithCallbacks multi_core_optimized_conv_sharded_v2_impl(

compute_kernel = "tt_eager/tt_dnn/op_library/conv/kernels/conv_bmm_tilize_col_major_out_blocks.cpp";
// Input should always be sharded in this conv; always use reader kernel for input shard with halo and padding
if (weight_size_h == weight_size_w and weight_size_w > 1 and (stride_h == 1 or stride_h == 2)) {
if (weight_size_h == weight_size_w and weight_size_w >= 1 and (stride_h == 1 or stride_h == 2)) {
if (weight_width_sliced) {
// 2D conv
assert(read_3x3_window_in_inner_loop == true);
assert(read_window_in_inner_loop == true);
reader_kernel =
"tt_eager/tt_dnn/op_library/conv/kernels/"
"reader_conv_activations_2d_mcast_padded_with_halo_3x3_weights_v2.cpp";
@@ -872,7 +874,7 @@ operation::ProgramWithCallbacks multi_core_optimized_conv_sharded_v2_impl(
TT_ASSERT(false, "Sharded input not supported for this conv yet!");
}

if (read_3x3_window_in_inner_loop) {
if (read_window_in_inner_loop) {
const uint32_t window_size = weight_size_h * weight_size_w;
in0_block_w *= window_size;
in0_block_num_tiles *= window_size;
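When the whole filter window is read in the inner loop, the in0 block and circular-buffer sizes scale by the window size, which is now 1 for 1x1 convs instead of an implicit 9. A rough illustration of the scaling (helper name and tile counts are illustrative):

    def scale_for_window(in0_block_w, in0_block_num_tiles, weight_size_h, weight_size_w):
        # Mirrors the read_window_in_inner_loop branch: scale by the filter window size.
        window_size = weight_size_h * weight_size_w
        return in0_block_w * window_size, in0_block_num_tiles * window_size

    assert scale_for_window(2, 4, 3, 3) == (18, 36)  # 3x3 window: x9
    assert scale_for_window(2, 4, 1, 1) == (2, 4)    # 1x1 window: unchanged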
@@ -905,6 +907,7 @@ operation::ProgramWithCallbacks multi_core_optimized_conv_sharded_v2_impl(
(uint32_t)act_mcast_receiver_semaphore,
(uint32_t)in0_block_num_tiles * tilized_act_tile_size, // act_mcast_sender_size_bytes
(uint32_t)(transpose_mcast ? 1 : 0),
(uint32_t)pad_w,
};

// define for bias
@@ -477,6 +477,7 @@ def __init__(
filter_height == filter_width
and filter_height == 1
and stride_h == stride_w
and stride_h == 1
and pad_h == pad_w
and pad_h == 0
):
2 changes: 1 addition & 1 deletion ttnn/cpp/ttnn/operations/conv2d.cpp
@@ -599,7 +599,7 @@ std::tuple<ttnn::Tensor, uint32_t, uint32_t, ttnn::Tensor, std::optional<ttnn::T
device);
}
// if 1x1 conv w/ stride 1, convert input tensor to tile layout if required
bool use_matmul_for_1x1_conv = kernel_size[0] == 1 && kernel_size[1] == 1 && stride[0] == stride[1] &&
bool use_matmul_for_1x1_conv = kernel_size[0] == 1 && kernel_size[1] == 1 && stride[0] == stride[1] && stride[0] == 1 &&
padding[0] == 0 && padding[1] == 0 && dilation[0] == 1 && dilation[1] == 1 &&
groups == 1;
Tensor input_tensor_post_tm_out;
2 changes: 1 addition & 1 deletion ttnn/ttnn/operations/conv2d.py
@@ -695,7 +695,7 @@ def conv2d(
input_tensor = ttnn.to_device(input_tensor, device=device, memory_config=input_tensor_sharded_memory_config)
# since we resharded/moved the input tensor, we can deallocate it after halo op within composite conv
conv_config.deallocate_activation = True
is_1x1_conv = kernel_size == (1, 1) and stride[0] == stride[1] and padding == (0, 0)
is_1x1_conv = kernel_size == (1, 1) and stride[0] == stride[1] and stride[0] == 1 and padding == (0, 0)
if is_1x1_conv and input_tensor.layout != ttnn.TILE_LAYOUT:
input_tensor = ttnn.to_layout(input_tensor, ttnn.TILE_LAYOUT, dtype=conv_config.dtype)
input_is_on_device = ttnn.is_tensor_storage_on_device(input_tensor)