From 37168908c52b63f48d9b8dc247e78bed5c5b15b1 Mon Sep 17 00:00:00 2001 From: Raymond Kim <109366641+tt-rkim@users.noreply.github.com> Date: Wed, 15 May 2024 10:21:20 -0400 Subject: [PATCH] #8260: Revert "#8260: reshard uneven shard" because it breaks perf (#8500) pipelines + nightly This reverts commit c95e6af65219cf3d5aca6bc5bbed5215c9557eaa. --- models/utility_functions.py | 2 +- .../unit_testing/misc/test_reshard.py | 77 +----- tests/ttnn/unit_tests/operations/test_core.py | 27 +- .../kernels/dataflow/reshard_reader.cpp | 21 +- .../multi_core/sharded_op_multi_core.cpp | 248 ++++++++---------- .../tt_dnn/op_library/sharded/sharded_op.hpp | 1 - 6 files changed, 132 insertions(+), 244 deletions(-) diff --git a/models/utility_functions.py b/models/utility_functions.py index 925b1cda7f9..b43664a62e9 100644 --- a/models/utility_functions.py +++ b/models/utility_functions.py @@ -1121,7 +1121,7 @@ def get_debug_tensor(num_pages_width, num_pages_height, dtype, page_width=32, pa tile_row = None for col_idx in range(0, int(num_pages_width)): tile_idx = col_idx + num_pages_width * row_idx - tile = torch.full((1, 1, page_height, page_width), tile_idx, dtype=dtype) + tile = torch.full((1, 1, page_width, page_height), tile_idx + 1, dtype=dtype) if tile_row == None: tile_row = tile else: diff --git a/tests/tt_eager/python_api_testing/unit_testing/misc/test_reshard.py b/tests/tt_eager/python_api_testing/unit_testing/misc/test_reshard.py index ced941eb058..fd35b9b5dc3 100644 --- a/tests/tt_eager/python_api_testing/unit_testing/misc/test_reshard.py +++ b/tests/tt_eager/python_api_testing/unit_testing/misc/test_reshard.py @@ -14,25 +14,7 @@ from enum import Enum -from models.utility_functions import skip_for_grayskull, get_debug_tensor - - -tt_dtype_to_torch_dtype = { - ttl.tensor.DataType.UINT32: torch.int32, - ttl.tensor.DataType.UINT16: torch.int16, - ttl.tensor.DataType.BFLOAT16: torch.bfloat16, - ttl.tensor.DataType.BFLOAT8_B: torch.float, -} -TILE_WIDTH = 32 -TILE_HEIGHT = 32 - - -def get_tensor(shape, dtype): - if dtype in {torch.int16, torch.int32}: - torch_tensor = torch.randint(0, 1024, shape, dtype=dtype) - else: - torch_tensor = torch.rand(shape, dtype=dtype) - return torch_tensor +from models.utility_functions import skip_for_wormhole_b0, skip_for_grayskull def run_reshard_test( @@ -49,16 +31,10 @@ def run_reshard_test( output_sharding_scheme, tt_dtype, ): - full_grid = device.compute_with_storage_grid_size() - input_shard_grid_set = set() for _input_shard_grid in input_shard_grid: compute_grid_start = ttl.tensor.CoreCoord(_input_shard_grid[0][0], _input_shard_grid[0][1]) compute_grid_end = ttl.tensor.CoreCoord(_input_shard_grid[1][0], _input_shard_grid[1][1]) - if compute_grid_start.x >= full_grid.x or compute_grid_start.y >= full_grid.y: - pytest.skip("Illegal input core_grid") - if compute_grid_end.x >= full_grid.x or compute_grid_end.y >= full_grid.y: - pytest.skip("Illegal input core_grid") input_shard_grid_set.add(ttl.tensor.CoreRange(compute_grid_start, compute_grid_end)) input_shard_grid = ttl.tensor.CoreRangeSet(input_shard_grid_set) @@ -67,10 +43,6 @@ def run_reshard_test( for _output_shard_grid in output_shard_grid: compute_grid_start = ttl.tensor.CoreCoord(_output_shard_grid[0][0], _output_shard_grid[0][1]) compute_grid_end = ttl.tensor.CoreCoord(_output_shard_grid[1][0], _output_shard_grid[1][1]) - if compute_grid_start.x >= full_grid.x or compute_grid_start.y >= full_grid.y: - pytest.skip("Illegal output core_grid") - if compute_grid_end.x >= full_grid.x or compute_grid_end.y >= full_grid.y: - pytest.skip("Illegal output core_grid") output_shard_grid_set.add(ttl.tensor.CoreRange(compute_grid_start, compute_grid_end)) output_shard_grid = ttl.tensor.CoreRangeSet(output_shard_grid_set) @@ -84,27 +56,7 @@ def run_reshard_test( memory_layout=ttl.tensor.TensorMemoryLayout.INTERLEAVED, buffer_type=ttl.tensor.BufferType.DRAM, ) - debug = True - dtype = tt_dtype_to_torch_dtype[tt_dtype] - if debug: - if input_layout == ttl.tensor.Layout.TILE: - num_pages_height = (input_shape[0] * input_shape[1] * input_shape[2]) / 32 - num_pages_width = input_shape[3] / 32 - page_height = 32 - page_width = 32 - else: - page_width_input = input_shard_shape[1] - page_width_output = output_shard_shape[1] - page_height = 1 - page_width = int(math.gcd(page_width_input, page_width_output)) - num_pages_height = int(input_shape[0] * input_shape[1] * input_shape[2]) - num_pages_width = int(input_shape[3] / page_width) - torch_tensor = get_debug_tensor( - num_pages_width, num_pages_height, dtype, page_width=page_width, page_height=page_height - ) - else: - torch_tensor = get_tensor(input_shape, dtype) - + torch_tensor = torch.randn(input_shape).bfloat16() tt_tensor_sharded = ttl.tensor.Tensor(torch_tensor, tt_dtype).to(input_layout) tt_tensor_sharded = tt_tensor_sharded.to(device, dram_memory_config) tt_tensor_sharded = ttl.tensor.interleaved_to_sharded( @@ -129,6 +81,7 @@ def run_reshard_test( return torch_tensor, torch_tensor_after_round_trip +@skip_for_wormhole_b0() @pytest.mark.parametrize( "input_shape, input_layout, input_shard_grid, input_shard_shape, input_shard_orientation, input_sharding_scheme, output_shard_grid, output_shard_shape, output_shard_orientation, output_sharding_scheme", [ @@ -204,18 +157,6 @@ def run_reshard_test( ttl.tensor.ShardOrientation.COL_MAJOR, ttl.tensor.TensorMemoryLayout.BLOCK_SHARDED, ), - ( - [1, 1, 160, 64], - ttl.tensor.Layout.TILE, - [[(0, 0), (0, 4)]], - (32, 64), - ttl.tensor.ShardOrientation.ROW_MAJOR, - ttl.tensor.TensorMemoryLayout.HEIGHT_SHARDED, - [[(0, 0), (1, 1)]], - (96, 32), - ttl.tensor.ShardOrientation.COL_MAJOR, - ttl.tensor.TensorMemoryLayout.BLOCK_SHARDED, - ), ], ) @pytest.mark.parametrize("tt_dtype", [ttl.tensor.DataType.BFLOAT16, ttl.tensor.DataType.BFLOAT8_B]) @@ -316,14 +257,14 @@ def test_reshard_rn50( "input_shape, input_layout, input_shard_grid, input_shard_shape, input_shard_orientation, input_sharding_scheme, output_shard_grid, output_shard_shape, output_shard_orientation, output_sharding_scheme", [ ( - [1, 1, 160, 64], + [1, 1, 32, 6272], ttl.tensor.Layout.TILE, - [[(0, 0), (0, 4)]], - (32, 64), + [[(0, 0), (6, 6)]], + (32, 128), ttl.tensor.ShardOrientation.ROW_MAJOR, - ttl.tensor.TensorMemoryLayout.HEIGHT_SHARDED, - [[(0, 0), (1, 1)]], - (96, 32), + ttl.tensor.TensorMemoryLayout.WIDTH_SHARDED, + [[(0, 0), (0, 6)]], + (32, 1024), ttl.tensor.ShardOrientation.COL_MAJOR, ttl.tensor.TensorMemoryLayout.BLOCK_SHARDED, ), diff --git a/tests/ttnn/unit_tests/operations/test_core.py b/tests/ttnn/unit_tests/operations/test_core.py index 0ae65e2dec9..8d2dfb15e54 100644 --- a/tests/ttnn/unit_tests/operations/test_core.py +++ b/tests/ttnn/unit_tests/operations/test_core.py @@ -181,23 +181,6 @@ None, None, ), - ( - 160, - 64, - ttnn.TILE_LAYOUT, - dict( - core_grid=ttnn.CoreGrid(y=5, x=1), - strategy=ttnn.ShardStrategy.HEIGHT, - orientation=ttnn.ShardOrientation.ROW_MAJOR, - ), - dict( - core_grid=ttnn.CoreGrid(y=2, x=2), - strategy=ttnn.ShardStrategy.BLOCK, - orientation=ttnn.ShardOrientation.COL_MAJOR, - ), - (32, 64), - (32, 96), - ), ], ) def test_reshard( @@ -210,11 +193,10 @@ def test_reshard( input_override, output_override, ): - if isinstance(input_sharded_memory_config_args["core_grid"], (ttnn.CoreGrid)): - if device.core_grid.y < input_sharded_memory_config_args["core_grid"].y: - pytest.skip() - if device.core_grid.y < output_sharded_memory_config_args["core_grid"].y: - pytest.skip() + if device.core_grid.y < input_sharded_memory_config_args["core_grid"].y: + pytest.skip() + if device.core_grid.y < output_sharded_memory_config_args["core_grid"].y: + pytest.skip() input_shape = [1, 1, input_height, input_width] torch_input_tensor = torch.rand(input_shape, dtype=torch.bfloat16) @@ -235,6 +217,7 @@ def test_reshard( output_shard_memory_config = ttnn.create_sharded_memory_config( output_override, **output_sharded_memory_config_args, use_height_and_width_as_shard_shape=True ) + # interleaved_to_sharded sharded_input_tensor = ttnn.to_memory_config(interleaved_input_tensor, input_shard_memory_config) diff --git a/tt_eager/tt_dnn/op_library/sharded/kernels/dataflow/reshard_reader.cpp b/tt_eager/tt_dnn/op_library/sharded/kernels/dataflow/reshard_reader.cpp index 8cb82f7bf42..265fc990503 100644 --- a/tt_eager/tt_dnn/op_library/sharded/kernels/dataflow/reshard_reader.cpp +++ b/tt_eager/tt_dnn/op_library/sharded/kernels/dataflow/reshard_reader.cpp @@ -35,28 +35,25 @@ void kernel_main() { const uint32_t start_y = get_arg_val(y_offset + start_y_index); const uint32_t stride_data_offset = get_arg_val(arg_index++); - const uint32_t stride_size_num_strides_skip = get_arg_val(arg_index++); - const uint32_t num_strides = ((stride_size_num_strides_skip) & mask_short) >> 8; - const bool skip = (((stride_size_num_strides_skip) & mask_byte) == 1); + const uint32_t stride_size_num_strides = get_arg_val(arg_index++); + const uint32_t num_strides = ((stride_size_num_strides) & mask_short); const uint32_t stride_data = ((stride_data_offset >> 16)) * page_size; const uint32_t offset = ((stride_data_offset) & mask_short) * page_size; - const uint32_t num_pages_per_stride = (stride_size_num_strides_skip >> 16); - const uint32_t stride_size = num_pages_per_stride * page_size; + const uint32_t stride_size = ((stride_size_num_strides >> 16)) * page_size; uint32_t addr_offset = offset; uint32_t core_id_x_index = start_x_index; uint32_t core_id_y_index = start_y_index; for(uint32_t stride_idx = 0; stride_idx < num_strides; stride_idx++) { - if(!skip) { - uint32_t core_id_x = get_arg_val(core_id_x_index); - uint32_t core_id_y = get_arg_val(y_offset + core_id_y_index); - uint64_t noc_address = get_noc_addr(core_id_x, core_id_y, - input_shard_addr + addr_offset); - noc_async_read(noc_address, l1_write_addr, stride_size); - } + + uint32_t core_id_x = get_arg_val(core_id_x_index); + uint32_t core_id_y = get_arg_val(y_offset + core_id_y_index); + uint64_t noc_address = get_noc_addr(core_id_x, core_id_y, + input_shard_addr + addr_offset); + noc_async_read(noc_address, l1_write_addr, stride_size); l1_write_addr+=stride_size; if(stride_x == 0 and stride_y == 0) { addr_offset += (stride_data + stride_size); diff --git a/tt_eager/tt_dnn/op_library/sharded/multi_core/sharded_op_multi_core.cpp b/tt_eager/tt_dnn/op_library/sharded/multi_core/sharded_op_multi_core.cpp index dc0e65e71cf..7e6267c814d 100644 --- a/tt_eager/tt_dnn/op_library/sharded/multi_core/sharded_op_multi_core.cpp +++ b/tt_eager/tt_dnn/op_library/sharded/multi_core/sharded_op_multi_core.cpp @@ -575,25 +575,28 @@ std::unordered_map> get_core_page_ranges( const auto& input_page_to_local_page_mapping = input_buffer_page_mapping.host_page_to_local_shard_page_mapping_; const auto& host_page_to_input_page_mapping = input_buffer_page_mapping.host_page_to_dev_page_mapping_; - auto output_cores = output_buffer_page_mapping.all_cores_; + auto num_pages = std::min(output_shard_to_host_mapping.size(), input_buffer->num_dev_pages()); + // First get output_core to vector< pair (num_pages_in_output) - std::vector >>> output_core_to_vector_input_core_page(output_cores.size()); + std::unordered_map>> output_core_to_vector_input_core_page; - for (uint32_t output_page_id = 0; output_page_id < output_buffer->num_dev_pages(); output_page_id++) { - auto output_core_id = output_buffer_page_mapping.dev_page_to_core_mapping_[output_page_id]; - TT_ASSERT(output_core_id < output_cores.size()); + for (uint32_t output_page_id = 0; output_page_id < num_pages; output_page_id++) { + auto output_core = output_buffer_page_mapping.all_cores_[output_buffer_page_mapping.dev_page_to_core_mapping_[output_page_id]]; auto host_page = output_shard_to_host_mapping[output_page_id]; - std::optional > mapped_page = std::nullopt; if(host_page.has_value()) { auto input_page = host_page_to_input_page_mapping[host_page.value()]; auto local_input_page = input_page_to_local_page_mapping[host_page.value()]; auto input_core = input_buffer_page_mapping.all_cores_[input_buffer_page_mapping.dev_page_to_core_mapping_[input_page]]; - mapped_page = std::make_optional > ({input_core, local_input_page}); + if (output_core_to_vector_input_core_page.find(output_core) == output_core_to_vector_input_core_page.end()) { + output_core_to_vector_input_core_page[output_core] = {{input_core, local_input_page}}; + } else { + output_core_to_vector_input_core_page[output_core].push_back({input_core, local_input_page}); + } } - output_core_to_vector_input_core_page[output_core_id].push_back(mapped_page); } // now compress to output_core to vector (num_page_ranges_in_output) + auto output_cores = corerange_to_cores(output_buffer->shard_spec().grid()); std::unordered_map> ret_map; ret_map.reserve(output_cores.size()); @@ -601,166 +604,131 @@ std::unordered_map> get_core_page_ranges( auto device = input_buffer->device(); auto full_grid = device->compute_with_storage_grid_size(); CoreCoord end_core = (*output_buffer->shard_spec().grid().ranges().rbegin()).end; - uint32_t output_core_id = 0; + uint32_t output_core_id; for (auto output_core : output_cores) { ret_map.try_emplace(output_core, std::vector{}); - const auto& input_cores_with_pages = output_core_to_vector_input_core_page[output_core_id]; + + const auto& input_cores_with_pages = output_core_to_vector_input_core_page.at(output_core); auto it = input_cores_with_pages.begin(); const auto end = input_cores_with_pages.end(); while (it != end) { - //hit padding, will see how many consecutive pages has padding to make a padded range - if(!it->has_value()) { + const auto start_core = it->first; + const auto start_page = it->second; + auto expected_next_page = start_page + 1; + Stride stride = Stride{.core = {0,0} , .data = 0}; + if ((it + 1) == end) { + ret_map[output_core].push_back(PageStride{.start_core = start_core, .start_data=it->second, .stride_size=1, .stride=stride, .num_strides=1}); + it = end; + } + else { + //first get a single stride, go through the number of consecutive pages in the same core auto consecutive_it = it+1; auto last_it_consec = it; while(consecutive_it != end) { - if(consecutive_it->has_value()) { + auto next_input_page = *(consecutive_it); + auto curr_input_page = *(last_it_consec); + // diff core , not consecutive + if(curr_input_page.first != next_input_page.first) { + break; + } + //not consecutive + else if ((curr_input_page.second + 1) != next_input_page.second) { break; } last_it_consec = consecutive_it; consecutive_it = consecutive_it+1; } uint32_t stride_size = std::distance(it, last_it_consec) + 1; - ret_map[output_core].push_back(PageStride{.start_core = output_core, .start_data=0, .stride_size=stride_size, .stride=Stride{.core = {0,0} , .data = 0}, .num_strides=1, .skip=true}); - it += stride_size; - } - else { - const auto start_core = it->value().first; - const auto start_page = it->value().second; - auto expected_next_page = start_page + 1; - Stride stride = Stride{.core = {0,0} , .data = 0}; - if ((it + 1) == end) { - ret_map[output_core].push_back(PageStride{.start_core = start_core, .start_data=it->value().second, .stride_size=1, .stride=stride, .num_strides=1, .skip = false}); - it = end; - } - else { - //first get a single stride, go through the number of consecutive pages in the same core - auto consecutive_it = it+1; - auto last_it_consec = it; - while(consecutive_it != end) { - auto next_input_page = *(consecutive_it); - auto curr_input_page = *(last_it_consec); - // diff core , not consecutive - if(curr_input_page.value().first != next_input_page.value().first) { - break; - } - //not consecutive - else if ((curr_input_page.value().second + 1) != next_input_page.value().second) { - break; - } - last_it_consec = consecutive_it; - consecutive_it = consecutive_it+1; + auto stride_it = it + stride_size; + auto last_it_stride = stride_it - 1; + + // if stride_range is within same core + // the jump in data is end of curr - end last stride + // if stride range is in diff core + // jump in data is curr - beginning of last stride + uint32_t data_stride; + if((stride_it != end) and (stride_it != it)){ + // data stride within core + if(stride_it->first == last_it_stride->first and (stride_it->second > last_it_stride->second) ) { + auto next_input_page = *(stride_it); + auto prev_input_page = *(last_it_stride); + data_stride = next_input_page.second - prev_input_page.second - 1; + stride = Stride{.core = {next_input_page.first.x - prev_input_page.first.x, next_input_page.first.y - prev_input_page.first.y}, + .data = data_stride}; } - uint32_t stride_size = std::distance(it, last_it_consec) + 1; - auto stride_it = it + stride_size; - auto last_it_stride = stride_it - 1; - - TT_ASSERT((stride_it == end) or stride_it->has_value()); - TT_ASSERT(last_it_stride->has_value()); - // if stride_range is within same core - // the jump in data is end of curr - end last stride - // if stride range is in diff core - // jump in data is curr - beginning of last stride - uint32_t data_stride; - if((stride_it != end) and (stride_it != it)){ - // data stride within core - if(stride_it->has_value() and - stride_it->value().first == last_it_stride->value().first and - (stride_it->value().second > last_it_stride->value().second) ) - { - auto next_input_page = *(stride_it); - auto prev_input_page = *(last_it_stride); - TT_ASSERT(prev_input_page.has_value()); - TT_ASSERT(next_input_page.has_value()); - data_stride = next_input_page.value().second - prev_input_page.value().second - 1; - stride = Stride{.core = {0, 0}, - .data = data_stride}; - } - // strided core but same data - // currently only handling increasing cores within same stride - // TODO : negative strides for cores - else if(stride_it->has_value() and - (stride_it->value().first != last_it_stride->value().first) and - (stride_it->value().first.x >= it->value().first.x and - stride_it->value().first.y >= it->value().first.y) and - (stride_it->value().second == it->value().second)) { - auto next_input_page = *(stride_it); - auto prev_input_page = *it; - TT_ASSERT(prev_input_page.has_value()); - TT_ASSERT(next_input_page.has_value()); - data_stride = 0; - stride = Stride{.core = {next_input_page.value().first.x - prev_input_page.value().first.x, next_input_page.value().first.y - prev_input_page.value().first.y}, - .data = data_stride}; - } - // diff data and diff core, not handled yet - else { - TT_ASSERT(it->has_value()); - ret_map[output_core].push_back(PageStride{.start_core = start_core, .start_data=it->value().second, .stride_size=stride_size, .stride=stride, .num_strides=1, .skip=false}); - it = stride_it; - continue; - } - //TODO add stride of data and core + // strided core but same data + // currently only handling increasing cores within same stride + // TODO : negative strides for cores + else if((stride_it->first != last_it_stride->first) and (stride_it->first.x >= it->first.x and stride_it->first.y >= it->first.y) and (stride_it->second == it->second)) { + //else { + auto next_input_page = *(stride_it); + auto prev_input_page = *it; + data_stride = 0; + stride = Stride{.core = {next_input_page.first.x - prev_input_page.first.x, next_input_page.first.y - prev_input_page.first.y}, + .data = data_stride}; } - // only single stride + // diff data and diff core, not handled yet else { - data_stride = 0; + ret_map[output_core].push_back(PageStride{.start_core = start_core, .start_data=it->second, .stride_size=stride_size, .stride=stride, .num_strides=1}); + it = stride_it; + continue; } + //TODO add stride of data and core + } + // only single stride + else { + data_stride = 0; + } - TT_ASSERT(stride.core.x < full_grid.x and stride.core.y < full_grid.y); - TT_ASSERT(data_stride < output_buffer->num_pages()); - auto stride_start = stride_it; - uint32_t num_strides = 1; - while(stride_it != end) { - bool stride_not_complete = false; - auto stride_it_inner = stride_it + 1; - auto last_it_stride_inner = stride_it; - for(uint32_t i=0; inum_pages()); + auto stride_start = stride_it; + uint32_t num_strides = 1; + while(stride_it != end) { + bool stride_not_complete = false; + auto stride_it_inner = stride_it + 1; + auto last_it_stride_inner = stride_it; + for(uint32_t i=0; ihas_value()); - ret_map[output_core].push_back(PageStride{.start_core = start_core, .start_data=it->value().second, .stride_size=stride_size, .stride=stride, .num_strides=num_strides, .skip=false}); - it = stride_it; } + ret_map[output_core].push_back(PageStride{.start_core = start_core, .start_data=it->second, .stride_size=stride_size, .stride=stride, .num_strides=num_strides}); + it = stride_it; } } - output_core_id++; } return ret_map; @@ -817,7 +785,7 @@ std::vector get_runtime_args_for_given_ranges(const std::vector