Skip to content

Commit

Permalink
#3827: Update interleaved_to_sharded to only read necessary data for …
Browse files Browse the repository at this point in the history
…uneven shards
  • Loading branch information
tt-aho committed Nov 30, 2023
1 parent f2e9220 commit 17739ba
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,17 +27,18 @@ void kernel_main() {
.data_format = data_format
};

uint32_t tile_id = start_id;
uint32_t curr_tile_id = start_id;
cb_reserve_back(cb_id_in0, block_num_tiles);
uint32_t l1_write_addr = get_write_ptr(cb_id_in0);
for (uint32_t h = 0; h < block_height_tiles; h++) {
uint32_t tile_id = curr_tile_id;
for (uint32_t w = 0; w < block_width_tiles; w++) {
noc_async_read_tile(tile_id, s, l1_write_addr);
tile_id++;
l1_write_addr += tile_bytes;
noc_async_read_barrier();
}
tile_id += input_width_offset_tiles;
curr_tile_id += input_width_offset_tiles;
}
cb_push_back(cb_id_in0, block_num_tiles);
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,17 +23,20 @@ operation::ProgramWithCallbacks interleaved_to_sharded_multi_core(
tt_metal::Program program{};

uint32_t num_units, num_units_per_shard, input_unit_size, output_unit_size, num_units_per_shard_width,
num_units_per_shard_height, num_units_offset, num_units_per_row;
num_units_per_shard_height, num_units_offset, num_units_per_row, num_units_per_shard_height_last,
num_units_per_shard_width_last;

tt_metal::Device* device = input.device();

tt::DataFormat input_cb_data_format = tt_metal::datatype_to_dataformat_converter(input.dtype());
tt::DataFormat output_cb_data_format = tt_metal::datatype_to_dataformat_converter(output.dtype());

auto shard_spec = output.shard_spec().value();
auto shard_strategy = output.memory_config().memory_layout;

bool rm_orientation = shard_spec.shard_orientation == ShardOrientation::ROW_MAJOR;

CoreCoord end_core = (*shard_spec.shard_grid.ranges().rbegin()).end;
if (input.layout() == Layout::TILE) {
num_units = input.volume() / TILE_HW;
input_unit_size = tt_metal::detail::TileSize(input_cb_data_format);
Expand All @@ -42,7 +45,12 @@ operation::ProgramWithCallbacks interleaved_to_sharded_multi_core(
num_units_per_shard_width = shard_spec.shard_shape[1] / TILE_WIDTH;
num_units_per_shard = num_units_per_shard_height * num_units_per_shard_width;
num_units_per_row = input.shape()[-1] / TILE_WIDTH;
num_units_offset = num_units_per_row - num_units_per_shard_width;
num_units_offset = num_units_per_row;
uint32_t num_units_height = input.volume() / input.shape()[-1] / TILE_HEIGHT;
num_units_per_shard_height_last =
num_units_per_shard_height - (round_up(num_units_height, num_units_per_shard_height) - num_units_height);
num_units_per_shard_width_last =
num_units_per_shard_width - (round_up(num_units_per_row, num_units_per_shard_width) - num_units_per_row);
} else {
num_units = (input.volume() / input.shape()[-1] / shard_spec.shard_shape[0]) *
(input.shape()[-1] / shard_spec.shard_shape[1]);
Expand All @@ -53,6 +61,11 @@ operation::ProgramWithCallbacks interleaved_to_sharded_multi_core(
num_units_per_shard = num_units_per_shard_height * num_units_per_shard_width;
num_units_per_row = input.shape()[-1] * input.element_size();
num_units_offset = 1;
uint32_t num_units_height = input.volume() / input.shape()[-1];
num_units_per_shard_height_last =
num_units_per_shard_height - (round_up(num_units_height, num_units_per_shard_height) - num_units_height);
num_units_per_shard_width_last =
input_unit_size - (round_up(num_units_per_row, input_unit_size) - num_units_per_row);
}

bool convert_df = input_cb_data_format != output_cb_data_format;
Expand Down Expand Up @@ -108,8 +121,7 @@ operation::ProgramWithCallbacks interleaved_to_sharded_multi_core(

unary_reader_kernel_id = tt_metal::CreateKernel(
program,
"tt_eager/tt_dnn/op_library/sharded/kernels/dataflow/"
"reader_unary_stick_layout_sharded_blocks_interleaved_start_id.cpp",
"tt_eager/tt_dnn/op_library/sharded/kernels/dataflow/reader_unary_stick_layout_sharded_blocks_interleaved_start_id.cpp",
all_cores,
tt_metal::DataMovementConfig{
.processor = tt_metal::DataMovementProcessor::RISCV_1,
Expand Down Expand Up @@ -144,13 +156,40 @@ operation::ProgramWithCallbacks interleaved_to_sharded_multi_core(
const auto cores = grid_to_cores(num_cores, num_cores_x, num_cores_y, rm_orientation);
for (const auto& core : cores) {
if (input.layout() == Layout::TILE) {
uint32_t shard_height = num_units_per_shard_height;
uint32_t shard_width = num_units_per_shard_width;
if (shard_strategy == TensorMemoryLayout::HEIGHT_SHARDED) {
if (core == end_core) {
shard_height = num_units_per_shard_height_last;
}
} else if (shard_strategy == TensorMemoryLayout::WIDTH_SHARDED) {
if (core == end_core) {
shard_width = num_units_per_shard_width_last;
}
} else if (shard_strategy == TensorMemoryLayout::BLOCK_SHARDED) {
if (rm_orientation) {
if (core.x == end_core.x) {
shard_width = num_units_per_shard_width_last;
}
if (core.y == end_core.y) {
shard_height = num_units_per_shard_height_last;
}
} else {
if (core.y == end_core.y) {
shard_width = num_units_per_shard_width_last;
}
if (core.x == end_core.x) {
shard_height = num_units_per_shard_height_last;
}
}
}
tt_metal::SetRuntimeArgs(
program,
unary_reader_kernel_id,
core,
{src_buffer->address(),
num_units_per_shard_height,
num_units_per_shard_width,
shard_height,
shard_width,
num_units_offset,
num_units_per_shard,
curr_idx_h + curr_idx_w});
Expand All @@ -160,14 +199,41 @@ operation::ProgramWithCallbacks interleaved_to_sharded_multi_core(
curr_idx_h += num_units_per_row * num_units_per_shard_height;
}
} else {
uint32_t shard_height = num_units_per_shard_height;
uint32_t shard_width = input_unit_size;
if (shard_strategy == TensorMemoryLayout::HEIGHT_SHARDED) {
if (core.x == end_core.x && core.y == end_core.y) {
shard_height = num_units_per_shard_height_last;
}
} else if (shard_strategy == TensorMemoryLayout::WIDTH_SHARDED) {
if (core.x == end_core.x && core.y == end_core.y) {
shard_width = num_units_per_shard_width_last;
}
} else if (shard_strategy == TensorMemoryLayout::BLOCK_SHARDED) {
if (rm_orientation) {
if (core.x == end_core.x) {
shard_width = num_units_per_shard_width_last;
}
if (core.y == end_core.y) {
shard_height = num_units_per_shard_height_last;
}
} else {
if (core.y == end_core.y) {
shard_width = num_units_per_shard_width_last;
}
if (core.x == end_core.x) {
shard_height = num_units_per_shard_height_last;
}
}
}
tt_metal::SetRuntimeArgs(
program,
unary_reader_kernel_id,
core,
{src_buffer->address(),
num_units_per_row,
num_units_per_shard_height,
input_unit_size,
shard_height,
shard_width,
curr_idx_w,
curr_idx_h});
curr_idx_w += input_unit_size;
Expand Down Expand Up @@ -220,7 +286,7 @@ operation::ProgramWithCallbacks sharded_to_interleaved_multi_core(
auto shard_strategy = input.memory_config().memory_layout;

bool rm_orientation = shard_spec.shard_orientation == ShardOrientation::ROW_MAJOR;
CoreCoord end_core = (*shard_spec.shard_grid.ranges().begin()).end;
CoreCoord end_core = (*shard_spec.shard_grid.ranges().rbegin()).end;
if (output.layout() == Layout::TILE) {
num_units = input.volume() / TILE_HW;
input_unit_size = tt_metal::detail::TileSize(input_cb_data_format);
Expand Down

0 comments on commit 17739ba

Please sign in to comment.