#0: Add helper function to create CBs
yan-zaretskiy committed May 31, 2024
1 parent 1c90b4f commit ab7e272
Showing 3 changed files with 138 additions and 136 deletions.
50 changes: 50 additions & 0 deletions tt_eager/tt_dnn/op_library/cb_utils.hpp
@@ -0,0 +1,50 @@
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#include "tt_metal/host_api.hpp"

namespace tt::tt_metal {

template <size_t N>
std::tuple<std::array<CB, N>, CBHandle> create_cb(
    const CB (&cbs)[N],
    Program &program,
    const std::variant<CoreCoord, CoreRange, CoreRangeSet> &core_spec,
    uint32_t page_size,
    uint32_t num_pages,
    const tt::DataFormat data_format,
    Buffer *buffer = nullptr) {
    std::map<uint8_t, tt::DataFormat> data_format_spec = {};
    for (auto cb : cbs) {
        data_format_spec[cb] = data_format;
    }

    auto cb_config = CircularBufferConfig(num_pages * page_size, data_format_spec);
    for (auto cb : cbs) {
        cb_config.set_page_size(cb, page_size);
    }

    if (buffer != nullptr) {
        cb_config.set_globally_allocated_address(*buffer);
    }

    std::array<CB, N> cbs_out;
    std::copy(cbs, cbs + N, cbs_out.begin());
    return std::make_tuple(cbs_out, tt_metal::CreateCircularBuffer(program, core_spec, cb_config));
}

inline std::tuple<CB, CBHandle> create_cb(
    CB cb,
    Program &program,
    const std::variant<CoreCoord, CoreRange, CoreRangeSet> &core_spec,
    uint32_t page_size,
    uint32_t num_pages,
    const tt::DataFormat data_format,
    Buffer *buffer = nullptr) {
    CB cbs[] = {cb};
    auto [_, handle] = create_cb(cbs, program, core_spec, page_size, num_pages, data_format, buffer);
    return std::make_tuple(cb, handle);
}

} // namespace tt::tt_metal
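
A minimal usage sketch of the new helper (not part of the commit itself), mirroring the call sites updated below; program, all_cores, tile_size, ntiles_per_shard, data_format, and output are placeholders for values the calling program factory already has in scope:

// Interleaved case: a plain double-buffered input CB; the returned index/handle
// can be ignored, as in the interleaved tilize/untilize paths below.
create_cb(CB::c_in0, program, all_cores, tile_size, /*num_pages=*/2, data_format);

// Sharded case: passing a Buffer* makes the CB globally allocated on top of that buffer.
auto [out_cb_index, cb_out] =
    create_cb(CB::c_out0, program, all_cores, tile_size, ntiles_per_shard, data_format, output.buffer());

// Array overload: several CB indices share one CircularBufferConfig and one handle.
CB in_cbs[] = {CB::c_in0, CB::c_in1};
auto [in_cb_indices, cb_in] = create_cb(in_cbs, program, all_cores, tile_size, /*num_pages=*/2, data_format);
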
@@ -4,6 +4,7 @@

#include <math.h>

#include "tt_dnn/op_library/cb_utils.hpp"
#include "tt_dnn/op_library/math.hpp"
#include "tt_dnn/op_library/operation.hpp"
#include "tt_dnn/op_library/work_split_tilize.hpp"
@@ -35,21 +36,10 @@ operation::ProgramWithCallbacks tilize_multi_core_interleaved(const Tensor& a, T
auto [ncores, all_cores, core_range, core_range_cliff, nblocks_per_core, nblocks_per_core_cliff] =
split_blocks_for_tilize(grid_size, nblocks);

- uint32_t src0_cb_index = CB::c_in0;
- uint32_t num_input_tiles = ntiles_per_block;
- tt_metal::CircularBufferConfig src0_cb_config =
-     tt_metal::CircularBufferConfig(
-         num_input_tiles * input_single_tile_size, {{src0_cb_index, input_cb_data_format}})
-         .set_page_size(src0_cb_index, input_single_tile_size);
- auto cb_src0 = tt_metal::CreateCircularBuffer(program, all_cores, src0_cb_config);

- uint32_t output_cb_index = CB::c_out0;
- uint32_t num_output_tiles = ntiles_per_block;
- tt_metal::CircularBufferConfig cb_output_config =
-     tt_metal::CircularBufferConfig(
-         num_output_tiles * output_single_tile_size, {{output_cb_index, output_cb_data_format}})
-         .set_page_size(output_cb_index, output_single_tile_size);
- auto cb_output = tt_metal::CreateCircularBuffer(program, all_cores, cb_output_config);
+ create_cb(CB::c_in0, program, all_cores, input_single_tile_size, ntiles_per_block, input_cb_data_format);

+ auto [output_cb_index, _] =
+     create_cb(CB::c_out0, program, all_cores, output_single_tile_size, ntiles_per_block, output_cb_data_format);

Buffer* src0_buffer = a.buffer();
Buffer* dst_buffer = output.buffer();
@@ -204,23 +194,23 @@ operation::ProgramWithCallbacks tilize_multi_core_sharded(const Tensor& input, T
uint32_t num_cores_x = device->compute_with_storage_grid_size().x;
uint32_t num_cores = all_cores.num_cores();

- uint32_t src0_cb_index = CB::c_in0;
- uint32_t num_input_tiles = num_tiles_per_shard;
- tt_metal::CircularBufferConfig src0_cb_config =
-     tt_metal::CircularBufferConfig(
-         num_input_tiles * input_single_tile_size, {{src0_cb_index, input_cb_data_format}})
-         .set_page_size(src0_cb_index, input_single_tile_size)
-         .set_globally_allocated_address(*input.buffer());
- auto cb_src0 = tt_metal::CreateCircularBuffer(program, all_cores, src0_cb_config);

- uint32_t output_cb_index = CB::c_out0;
- uint32_t num_output_tiles = num_tiles_per_shard;
- tt_metal::CircularBufferConfig cb_output_config =
-     tt_metal::CircularBufferConfig(
-         num_output_tiles * output_single_tile_size, {{output_cb_index, output_cb_data_format}})
-         .set_page_size(output_cb_index, output_single_tile_size)
-         .set_globally_allocated_address(*output.buffer());
- auto cb_output = tt_metal::CreateCircularBuffer(program, all_cores, cb_output_config);
+ auto [src0_cb_index, cb_src0] = create_cb(
+     CB::c_in0,
+     program,
+     all_cores,
+     input_single_tile_size,
+     num_tiles_per_shard,
+     input_cb_data_format,
+     input.buffer());

+ auto [output_cb_index, cb_output] = create_cb(
+     CB::c_out0,
+     program,
+     all_cores,
+     output_single_tile_size,
+     num_tiles_per_shard,
+     output_cb_data_format,
+     output.buffer());

auto src_buffer = input.buffer();

@@ -307,19 +297,11 @@ operation::ProgramWithCallbacks tilize_with_val_padding_multi_core_interleaved(
uint32_t unpadded_row_size_bytes = input_shape[-1] * a.element_size(); // Assuming bfloat16 dataformat
uint32_t padded_row_size_bytes = output_shape[-1] * a.element_size(); // Assuming bfloat16 dataformat

- uint32_t src0_cb_index = CB::c_in0;
- tt_metal::CircularBufferConfig src0_cb_config =
-     tt_metal::CircularBufferConfig(
-         num_tiles_per_row * input_single_tile_size, {{src0_cb_index, input_cb_data_format}})
-         .set_page_size(src0_cb_index, input_single_tile_size);
- auto cb_src0 = tt_metal::CreateCircularBuffer(program, all_cores, src0_cb_config);
+ auto [src0_cb_index, cb_src0] =
+     create_cb(CB::c_in0, program, all_cores, input_single_tile_size, num_tiles_per_row, input_cb_data_format);

- uint32_t output_cb_index = CB::c_out0;
- tt_metal::CircularBufferConfig cb_output_config =
-     tt_metal::CircularBufferConfig(
-         num_tiles_per_row * output_single_tile_size, {{output_cb_index, output_cb_data_format}})
-         .set_page_size(output_cb_index, output_single_tile_size);
- auto cb_output = tt_metal::CreateCircularBuffer(program, all_cores, cb_output_config);
+ auto [output_cb_index, cb_output] =
+     create_cb(CB::c_out0, program, all_cores, output_single_tile_size, num_tiles_per_row, output_cb_data_format);

Buffer* src0_buffer = a.buffer();
Buffer* dst_buffer = output.buffer();
@@ -469,48 +451,35 @@ operation::ProgramWithCallbacks tilize_with_val_padding_multi_core_sharded(

uint32_t num_input_rows = input_shard_spec.shape[0];
uint32_t input_shard_width_bytes = input_shard_spec.shape[1] * a.element_size();
- uint32_t input_shard_size_bytes = num_input_rows * input_shard_width_bytes;
uint32_t ntiles_per_core = output_shard_spec.shape[0] * output_shard_spec.shape[1] / TILE_HW;
uint32_t ntiles_per_batch = ntiles_per_core / num_batches;
uint32_t ntiles_per_block = output_shard_spec.shape[1] / TILE_WIDTH;
uint32_t nblocks_per_core = output_shard_spec.shape[0] / TILE_HEIGHT;
uint32_t num_padded_rows = output.get_legacy_shape()[-2] - a.get_legacy_shape()[-2];

- uint32_t src0_cb_index = CB::c_in1;
+ auto [src0_cb_index, cb_src0] = create_cb(
+     CB::c_in1,
+     program,
+     all_cores,
+     input_shard_width_bytes,
+     num_input_rows,
+     input_cb_data_format,
+     src_sharded ? a.buffer() : nullptr);

+ auto [src1_cb_index, cb_src1] =
+     create_cb(CB::c_in0, program, all_cores, input_single_tile_size, ntiles_per_batch * 2, input_cb_data_format);

- tt_metal::CircularBufferConfig src0_cb_config =
-     tt_metal::CircularBufferConfig(input_shard_size_bytes, {{src0_cb_index, input_cb_data_format}})
-         .set_page_size(src0_cb_index, input_shard_width_bytes);
- if (src_sharded) {
-     src0_cb_config = src0_cb_config.set_globally_allocated_address(*a.buffer());
- }
- auto cb_src0 = tt_metal::CreateCircularBuffer(program, all_cores, src0_cb_config);

- uint32_t src1_cb_index = CB::c_in0;
- uint32_t num_padded_input_tiles = ntiles_per_batch * 2;
- tt_metal::CircularBufferConfig src1_cb_config =
-     tt_metal::CircularBufferConfig(
-         num_padded_input_tiles * input_single_tile_size, {{src1_cb_index, input_cb_data_format}})
-         .set_page_size(src1_cb_index, input_single_tile_size);

- auto cb_src1 = tt_metal::CreateCircularBuffer(program, all_cores, src1_cb_config);

- uint32_t src2_cb_index = CB::c_in2;
- tt_metal::CircularBufferConfig src2_cb_config =
-     tt_metal::CircularBufferConfig(1 * input_shard_width_bytes, {{src2_cb_index, input_cb_data_format}})
-         .set_page_size(src2_cb_index, input_shard_width_bytes);

- auto cb_src2 = tt_metal::CreateCircularBuffer(program, all_cores, src2_cb_config);

- uint32_t output_cb_index = CB::c_out0;
- tt_metal::CircularBufferConfig output_cb_config =
-     tt_metal::CircularBufferConfig(
-         ntiles_per_core * output_single_tile_size, {{output_cb_index, output_cb_data_format}})
-         .set_page_size(output_cb_index, output_single_tile_size);
- if (out_sharded) {
-     output_cb_config.set_globally_allocated_address(*output.buffer());
- }
- auto cb_output = tt_metal::CreateCircularBuffer(program, all_cores, output_cb_config);
+ auto [src2_cb_index, cb_src2] =
+     create_cb(CB::c_in2, program, all_cores, input_shard_width_bytes, 1, input_cb_data_format);

+ auto [output_cb_index, cb_output] = create_cb(
+     CB::c_out0,
+     program,
+     all_cores,
+     output_single_tile_size,
+     ntiles_per_core,
+     output_cb_data_format,
+     out_sharded ? output.buffer() : nullptr);

Buffer* src0_buffer = a.buffer();
Buffer* dst_buffer = output.buffer();
@@ -4,6 +4,7 @@

#include <math.h>

#include "tt_dnn/op_library/cb_utils.hpp"
#include "tt_dnn/op_library/math.hpp"
#include "tt_dnn/op_library/untilize/untilize_op.hpp"
#include "tt_dnn/op_library/work_split_tilize.hpp"
@@ -88,27 +89,25 @@ operation::ProgramWithCallbacks untilize_multi_core(
end_core = (*shard_spec.grid.ranges().begin()).end;
}

- uint32_t src0_cb_index = CB::c_in0;
uint32_t num_input_tiles = src_sharded ? ntiles_per_block * nblocks_per_core : ntiles_per_block * 2;
- tt_metal::CircularBufferConfig src0_cb_config =
-     tt_metal::CircularBufferConfig(
-         num_input_tiles * input_single_tile_size, {{src0_cb_index, input_cb_data_format}})
-         .set_page_size(src0_cb_index, input_single_tile_size);
- if (src_sharded) {
-     src0_cb_config = src0_cb_config.set_globally_allocated_address(*a.buffer());
- }
- auto cb_src0 = tt_metal::CreateCircularBuffer(program, all_cores, src0_cb_config);
+ auto [src0_cb_index, cb_src0] = create_cb(
+     CB::c_in0,
+     program,
+     all_cores,
+     input_single_tile_size,
+     num_input_tiles,
+     input_cb_data_format,
+     src_sharded ? a.buffer() : nullptr);

- uint32_t output_cb_index = CB::c_out0;
uint32_t num_output_tiles = out_sharded ? ntiles_per_block * nblocks_per_core : ntiles_per_block * 2;
- tt_metal::CircularBufferConfig output_cb_config =
-     tt_metal::CircularBufferConfig(
-         num_output_tiles * output_single_tile_size, {{output_cb_index, output_cb_data_format}})
-         .set_page_size(output_cb_index, output_single_tile_size);
- if (out_sharded) {
-     output_cb_config = output_cb_config.set_globally_allocated_address(*output.buffer());
- }
- auto cb_output = tt_metal::CreateCircularBuffer(program, all_cores, output_cb_config);
+ auto [output_cb_index, cb_output] = create_cb(
+     CB::c_out0,
+     program,
+     all_cores,
+     output_single_tile_size,
+     num_output_tiles,
+     output_cb_data_format,
+     out_sharded ? output.buffer() : nullptr);

Buffer* src0_buffer = a.buffer();
Buffer* dst_buffer = output.buffer();
@@ -459,19 +458,8 @@ operation::ProgramWithCallbacks untilize_with_unpadding_multi_core_interleaved(
uint32_t padded_row_size_bytes = input_shape[-1] * a.element_size(); // Assuming bfloat16 dataformat
uint32_t unpadded_row_size_bytes = output_shape[-1] * a.element_size(); // Assuming bfloat16 dataformat

- uint32_t src0_cb_index = CB::c_in0;
- tt_metal::CircularBufferConfig src0_cb_config =
-     tt_metal::CircularBufferConfig(
-         num_tiles_per_row * input_single_tile_size, {{src0_cb_index, input_cb_data_format}})
-         .set_page_size(src0_cb_index, input_single_tile_size);
- auto cb_src0 = tt_metal::CreateCircularBuffer(program, all_cores, src0_cb_config);

- uint32_t output_cb_index = CB::c_out0;
- tt_metal::CircularBufferConfig cb_output_config =
-     tt_metal::CircularBufferConfig(
-         num_tiles_per_row * output_single_tile_size, {{output_cb_index, output_cb_data_format}})
-         .set_page_size(output_cb_index, output_single_tile_size);
- auto cb_output = tt_metal::CreateCircularBuffer(program, all_cores, cb_output_config);
+ create_cb(CB::c_in0, program, all_cores, input_single_tile_size, num_tiles_per_row, input_cb_data_format);
+ create_cb(CB::c_out0, program, all_cores, output_single_tile_size, num_tiles_per_row, output_cb_data_format);

Buffer* src0_buffer = a.buffer();
Buffer* dst_buffer = output.buffer();
@@ -666,35 +654,30 @@ operation::ProgramWithCallbacks untilize_with_unpadding_multi_core_sharded(
if (!row_major) {
std::swap(end_core.x, end_core.y);
}
- uint32_t src0_cb_index = CB::c_in0;

uint32_t num_input_tiles = ntiles_per_block * nblocks_per_core;
- tt_metal::CircularBufferConfig src0_cb_config =
-     tt_metal::CircularBufferConfig(
-         num_input_tiles * input_single_tile_size, {{src0_cb_index, input_cb_data_format}})
-         .set_page_size(src0_cb_index, input_single_tile_size);
- if (src_sharded) {
-     src0_cb_config = src0_cb_config.set_globally_allocated_address(*a.buffer());
- }
- auto cb_src0 = tt_metal::CreateCircularBuffer(program, all_cores, src0_cb_config);
+ auto [src0_cb_index, cb_src0] = create_cb(
+     CB::c_in0,
+     program,
+     all_cores,
+     input_single_tile_size,
+     num_input_tiles,
+     input_cb_data_format,
+     src_sharded ? a.buffer() : nullptr);

- uint32_t output_cb_index = CB::c_out0;
uint32_t num_output_tiles = out_sharded ? ntiles_per_batch * 2 : ntiles_per_block * 2;
- tt_metal::CircularBufferConfig output_cb_config =
-     tt_metal::CircularBufferConfig(
-         num_output_tiles * output_single_tile_size, {{output_cb_index, output_cb_data_format}})
-         .set_page_size(output_cb_index, output_single_tile_size);
- auto cb_output = tt_metal::CreateCircularBuffer(program, all_cores, output_cb_config);

- CBHandle cb_sharded_output = 0;
- uint32_t sharded_output_cb_index = CB::c_out1;
- if (out_sharded) {
-     tt_metal::CircularBufferConfig sharded_output_cb_config =
-         tt_metal::CircularBufferConfig(
-             num_output_rows_unpadded * block_row_size, {{sharded_output_cb_index, output_cb_data_format}})
-             .set_page_size(sharded_output_cb_index, block_row_size)
-             .set_globally_allocated_address(*output.buffer());
-     cb_sharded_output = tt_metal::CreateCircularBuffer(program, all_cores, sharded_output_cb_config);
- }
+ auto [output_cb_index, cb_output] =
+     create_cb(CB::c_out0, program, all_cores, output_single_tile_size, num_output_tiles, output_cb_data_format);

+ auto [sharded_output_cb_index, cb_sharded_output] = out_sharded ? create_cb(
+     CB::c_out1,
+     program,
+     all_cores,
+     block_row_size,
+     num_output_rows_unpadded,
+     output_cb_data_format,
+     output.buffer())
+     : std::make_tuple(CB::c_out1, CBHandle{});

Buffer* src0_buffer = a.buffer();
Buffer* dst_buffer = output.buffer();