Skip to content

Commit

Permalink
#0: Cleanup bmm multi core reuse optimized ORTAs
Browse files Browse the repository at this point in the history
  • Loading branch information
tt-aho committed Jun 7, 2024
1 parent 7575dd9 commit 53d0624
Showing 1 changed file with 19 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -296,9 +296,6 @@ operation::ProgramWithCallbacks create_program(
}
auto cb_output = tt_metal::CreateCircularBuffer(program, CoreRangeSet({all_cores}), output_cb_config);

std::vector<KernelHandle> reader_kernel_ids;
std::vector<KernelHandle> writer_kernel_ids;

// Write runtime args to device
std::vector<uint32_t> mm_reader_args = {
(std::uint32_t) num_blocks, // num_blocks
Expand Down Expand Up @@ -370,15 +367,13 @@ operation::ProgramWithCallbacks create_program(

tt_metal::SetRuntimeArgs(program, mm_kernel_in0_reader_id, core, mm_reader_args);
tt_metal::SetRuntimeArgs(program, mm_kernel_in1_reader_writer_id, core, mm_writer_args);
reader_kernel_ids.push_back(mm_kernel_in0_reader_id);
writer_kernel_ids.push_back(mm_kernel_in1_reader_writer_id);

num_blocks_written += num_output_blocks_per_core;
}

auto override_runtime_arguments_callback = [
reader_kernel_ids,
writer_kernel_ids,
mm_kernel_in0_reader_id,
mm_kernel_in1_reader_writer_id,
cb_src0,
cb_src1,
cb_output,
Expand All @@ -398,23 +393,28 @@ operation::ProgramWithCallbacks create_program(

auto dst_buffer = output_tensors.at(0).buffer();

const bool src0_sharded = input_tensors.at(0).memory_config().is_sharded();
const bool src1_sharded = input_tensors.at(1).memory_config().is_sharded();
const bool out_sharded = output_tensors.at(0).memory_config().is_sharded();
const bool src0_sharded = input_tensors[0].memory_config().is_sharded();
const bool src1_sharded = input_tensors[1].memory_config().is_sharded();
const bool out_sharded = output_tensors[0].memory_config().is_sharded();

const bool update_reader_args = !src0_sharded;

const bool update_writer_args = !(src1_sharded and out_sharded);

if (update_reader_args || update_writer_args) {

auto& reader_runtime_args_by_core = GetRuntimeArgs(program, mm_kernel_in0_reader_id);

if (!(src0_sharded and src1_sharded and out_sharded)) {
for (uint32_t i = 0; i < cores.size(); ++i) {
const CoreCoord& core = cores[i];
auto& writer_runtime_args_by_core = GetRuntimeArgs(program, mm_kernel_in1_reader_writer_id);

if (!src0_sharded) {
auto reader_kernel_id = reader_kernel_ids.at(i);
auto &runtime_args = GetRuntimeArgs(program, reader_kernel_id, core);
for (const auto& core : cores) {
if (update_reader_args) {
auto &runtime_args = reader_runtime_args_by_core[core.x][core.y];
runtime_args[4] = src_buffer_a->address();
}

if (!(src1_sharded and out_sharded)) {
auto writer_kernel_id = writer_kernel_ids.at(i);
auto &runtime_args = GetRuntimeArgs(program, writer_kernel_id, core);
if (update_writer_args) {
auto &runtime_args = writer_runtime_args_by_core[core.x][core.y];
runtime_args[5] = src_buffer_b->address();
runtime_args[13] = dst_buffer->address();
}
Expand Down

0 comments on commit 53d0624

Please sign in to comment.