Skip to content

Commit

Permalink
#8260: add skip padding in reshard kernel
Browse files Browse the repository at this point in the history
  • Loading branch information
ntarafdar committed May 22, 2024
1 parent e2c5f52 commit 959a007
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -35,26 +35,29 @@ void kernel_main() {
const uint32_t start_y = get_arg_val<uint32_t>(y_offset + start_y_index);

const uint32_t stride_data_offset = get_arg_val<uint32_t>(arg_index++);
const uint32_t stride_size_num_strides = get_arg_val<uint32_t>(arg_index++);
const uint32_t num_strides = ((stride_size_num_strides) & mask_short);
const uint32_t stride_size_num_strides_skip = get_arg_val<uint32_t>(arg_index++);
const uint32_t num_strides = ((stride_size_num_strides_skip) & mask_short) >> 8;
const bool skip = (((stride_size_num_strides_skip) & mask_byte) == 1);


const uint32_t stride_data = ((stride_data_offset >> 16)) * page_size;
const uint32_t offset = ((stride_data_offset) & mask_short) * page_size;
const uint32_t stride_size = ((stride_size_num_strides >> 16)) * page_size;
const uint32_t num_pages_per_stride = (stride_size_num_strides_skip >> 16);
const uint32_t stride_size = num_pages_per_stride * page_size;

uint32_t addr_offset = offset;
uint32_t core_id_x_index = start_x_index;
uint32_t core_id_y_index = start_y_index;

for(uint32_t stride_idx = 0; stride_idx < num_strides; stride_idx++) {

uint32_t core_id_x = get_arg_val<uint32_t>(core_id_x_index);
uint32_t core_id_y = get_arg_val<uint32_t>(y_offset + core_id_y_index);
uint64_t noc_address = get_noc_addr(core_id_x, core_id_y,
input_shard_addr + addr_offset);
noc_async_read(noc_address, l1_write_addr, stride_size);
l1_write_addr+=stride_size;
if(!skip) {
uint32_t core_id_x = get_arg_val<uint32_t>(core_id_x_index);
uint32_t core_id_y = get_arg_val<uint32_t>(y_offset + core_id_y_index);
uint64_t noc_address = get_noc_addr(core_id_x, core_id_y,
input_shard_addr + addr_offset);
noc_async_read(noc_address, l1_write_addr, stride_size);
l1_write_addr+=stride_size;
}
if(stride_x == 0 and stride_y == 0) {
addr_offset += (stride_data + stride_size);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -620,7 +620,7 @@ std::unordered_map<CoreCoord, std::vector<PageStride>> get_core_page_ranges(
auto expected_next_page = start_page + 1;
Stride stride = Stride{.core = {0,0} , .data = 0};
if ((it + 1) == end) {
ret_map[output_core].push_back(PageStride{.start_core = start_core, .start_data=it->second, .stride_size=1, .stride=stride, .num_strides=1});
ret_map[output_core].push_back(PageStride{.start_core = start_core, .start_data=it->second, .stride_size=1, .stride=stride, .num_strides=1, .skip=false});
it = end;
}
else {
Expand Down Expand Up @@ -672,7 +672,7 @@ std::unordered_map<CoreCoord, std::vector<PageStride>> get_core_page_ranges(
}
// diff data and diff core, not handled yet
else {
ret_map[output_core].push_back(PageStride{.start_core = start_core, .start_data=it->second, .stride_size=stride_size, .stride=stride, .num_strides=1});
ret_map[output_core].push_back(PageStride{.start_core = start_core, .start_data=it->second, .stride_size=stride_size, .stride=stride, .num_strides=1, .skip=false});
it = stride_it;
continue;
}
Expand Down Expand Up @@ -725,7 +725,7 @@ std::unordered_map<CoreCoord, std::vector<PageStride>> get_core_page_ranges(
break;
}
}
ret_map[output_core].push_back(PageStride{.start_core = start_core, .start_data=it->second, .stride_size=stride_size, .stride=stride, .num_strides=num_strides});
ret_map[output_core].push_back(PageStride{.start_core = start_core, .start_data=it->second, .stride_size=stride_size, .stride=stride, .num_strides=num_strides, .skip=false});
it = stride_it;
}
}
Expand Down Expand Up @@ -785,7 +785,7 @@ std::vector<uint32_t> get_runtime_args_for_given_ranges(const std::vector<uint32
runtime_args.push_back((uint32_t)core_start_stride); //start_x
uint32_t stride_data_start = (ps.stride.data << 16) | (start_data);
runtime_args.push_back((uint32_t)stride_data_start); //stride_data
uint32_t stride_size_num_strides = (ps.stride_size << 16) | (num_strides);
uint32_t stride_size_num_strides = (ps.stride_size << 16) | (num_strides << 8) | ((uint32_t)ps.skip);
runtime_args.push_back((uint32_t)stride_size_num_strides); // stride_size
num_output_pages += ps.stride_size * num_strides;
}
Expand Down
1 change: 1 addition & 0 deletions tt_eager/tt_dnn/op_library/sharded/sharded_op.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ struct PageStride {
uint32_t stride_size; //number of pages per stride
Stride stride;
uint32_t num_strides;
bool skip;
};

struct CorePageRange {
Expand Down

0 comments on commit 959a007

Please sign in to comment.