-
Notifications
You must be signed in to change notification settings - Fork 74
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
#3900: Add prod support for batch and channels
- Loading branch information
Showing
7 changed files
with
375 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
#include <cstdint> | ||
|
||
#include "compute_kernel_api/eltwise_binary.h" | ||
#include "compute_kernel_api/reduce.h" | ||
#include "compute_kernel_api/tile_move_copy.h" | ||
|
||
// Acquires the destination registers in half mode for the section that follows.
ALWI void ACQ() {
    acquire_dst(tt::DstMode::Half);
}
// Releases the destination registers previously acquired in half mode by ACQ().
ALWI void REL() {
    release_dst(tt::DstMode::Half);
}
|
||
namespace NAMESPACE { | ||
void MAIN { | ||
constexpr int onetile = 1; | ||
uint32_t per_core_block_cnt = get_arg_val<uint32_t>(0); | ||
binary_op_init_common(tt::CB::c_in0, tt::CB::c_in1); | ||
bool enable_reload = false; | ||
for (uint32_t block = 0; block < per_core_block_cnt; ++block) { | ||
bool last_out = block == (per_core_block_cnt - 1); | ||
|
||
// elemwise-mul | ||
ACQ(); | ||
cb_wait_front(tt::CB::c_in0, onetile); | ||
cb_wait_front(tt::CB::c_in1, onetile); | ||
|
||
cb_reserve_back(tt::CB::c_intermed0, onetile); | ||
mul_tiles_init(); | ||
// dst0 = c_in0 x c_in1 | ||
mul_tiles(tt::CB::c_in0, tt::CB::c_in1, 0, 0, 0); | ||
// c_intermed0 = pack(dst0) | ||
pack_tile(0, tt::CB::c_intermed0); | ||
cb_push_back(tt::CB::c_intermed0, onetile); | ||
|
||
cb_pop_front(tt::CB::c_in0, onetile); | ||
cb_pop_front(tt::CB::c_in1, onetile); | ||
REL(); | ||
|
||
// reduce-w | ||
ACQ(); | ||
if (enable_reload) { | ||
cb_wait_front(tt::CB::c_intermed1, onetile); | ||
copy_tile_to_dst_init_short(); | ||
copy_tile(tt::CB::c_intermed1, 0, 0); | ||
cb_pop_front(tt::CB::c_intermed1, onetile); | ||
} | ||
|
||
if (last_out) { | ||
cb_reserve_back(tt::CB::c_out0, onetile); | ||
pack_tile(0, tt::CB::c_out0); | ||
cb_push_back(tt::CB::c_out0, onetile); | ||
} else { | ||
cb_reserve_back(tt::CB::c_intermed1, onetile); | ||
pack_tile(0, tt::CB::c_intermed1); | ||
cb_push_back(tt::CB::c_intermed1, onetile); | ||
} | ||
REL(); | ||
enable_reload = true; | ||
} | ||
} | ||
} // namespace NAMESPACE |
114 changes: 114 additions & 0 deletions
114
tt_eager/tt_dnn/op_library/prod/kernels/reader_prod_hw.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,114 @@ | ||
// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
#include <stdint.h> | ||
|
||
#include "dataflow_api.h" | ||
|
||
// Zeroes the padding of one 16x16 face of 16-bit elements stored row-major.
// Rows below `valid_h` keep their first `valid_w` elements; every other
// element of the face is set to zero.
static inline void zero_face_padding(uint16_t *face, uint32_t valid_w, uint32_t valid_h) {
    constexpr uint32_t face_dim = 16;
    for (uint32_t h = 0; h < face_dim; ++h) {
        // A row past the valid height is cleared from column 0.
        for (uint32_t w = (h < valid_h) ? valid_w : 0; w < face_dim; ++w) {
            face[h * face_dim + w] = 0;
        }
    }
}

// Pointer-based worker behind mask_tile_in_reader(): splits the tile-level mask
// into per-face limits once, then clears the padding face by face.
static inline void mask_tile_padding(uint16_t *tile, uint32_t mask_w, uint32_t mask_h) {
    constexpr uint32_t face_dim = 16;
    constexpr uint32_t face_elems = face_dim * face_dim;

    const uint32_t w_left = (mask_w >= face_dim) ? face_dim : mask_w;
    const uint32_t w_right = (mask_w < face_dim) ? 0 : mask_w - face_dim;
    const uint32_t h_top = (mask_h >= face_dim) ? face_dim : mask_h;
    const uint32_t h_bottom = (mask_h < face_dim) ? 0 : mask_h - face_dim;

    zero_face_padding(tile + 0 * face_elems, w_left, h_top);      // face 0: top rows, left columns
    zero_face_padding(tile + 1 * face_elems, w_right, h_top);     // face 1: top rows, right columns
    zero_face_padding(tile + 2 * face_elems, w_left, h_bottom);   // face 2: bottom rows, left columns
    zero_face_padding(tile + 3 * face_elems, w_right, h_bottom);  // face 3: bottom rows, right columns
}

// Zeroes every element of a tile that lies outside the top-left
// `mask_h` x `mask_w` region.
//
// The tile is addressed as 1024 16-bit elements made of four consecutive
// 16x16 faces (element offsets 0, 256, 512 and 768).
//
// l1_addr: address of the tile, reinterpreted as a uint16_t pointer.
// mask_w:  number of valid columns; the default of 32 leaves all columns intact.
// mask_h:  number of valid rows; the default of 32 leaves all rows intact.
//
// The value written is the all-zero bit pattern. The previous implementation
// obtained it from the upper 16 bits of float 0.0f through a union, which is
// always 0, so the resulting memory contents are unchanged.
void mask_tile_in_reader(uint32_t l1_addr, uint32_t mask_w = 32, uint32_t mask_h = 32) {
    mask_tile_padding(reinterpret_cast<uint16_t *>(l1_addr), mask_w, mask_h);
}
|
||
void kernel_main() { | ||
// same arg indices as in reader_binary_diff_lenghts for compat | ||
uint32_t src0_addr = get_arg_val<uint32_t>(0); | ||
uint32_t src1_addr = get_arg_val<uint32_t>(1); | ||
uint32_t num_tiles = get_arg_val<uint32_t>(2); | ||
uint32_t start_id = get_arg_val<uint32_t>(3); | ||
uint32_t mask_h = get_arg_val<uint32_t>(4); | ||
uint32_t mask_w = get_arg_val<uint32_t>(5); | ||
|
||
constexpr bool src0_is_dram = get_compile_time_arg_val(0) == 1; | ||
constexpr bool src1_is_dram = get_compile_time_arg_val(1) == 1; | ||
constexpr uint32_t scaler = get_compile_time_arg_val(2); | ||
|
||
constexpr uint32_t cb_id_in0 = 0; | ||
constexpr uint32_t cb_id_in1 = 1; | ||
constexpr uint32_t cb_id_in2 = 2; | ||
cb_reserve_back(cb_id_in2, 1); | ||
if (scaler != 0) { | ||
auto ptr = reinterpret_cast<uint16_t *>(get_write_ptr(cb_id_in2)); | ||
for (int j = 0; j < 1024; j++) ptr[j] = uint16_t(0); | ||
|
||
for (int k = 0; k < 4; k++) | ||
for (int j = 0; j < 16; j++) ptr[k * 256 + j] = uint16_t(scaler >> 16); | ||
} | ||
cb_push_back(cb_id_in2, 1); | ||
|
||
uint32_t l1_write_addr_in0; | ||
uint32_t src0_tile_bytes = get_tile_size(cb_id_in0); | ||
DataFormat src0_data_format = get_dataformat(cb_id_in0); | ||
const InterleavedAddrGenFast<src0_is_dram> s0 = { | ||
.bank_base_address = src0_addr, .page_size = src0_tile_bytes, .data_format = src0_data_format}; | ||
uint32_t l1_write_addr_in1; | ||
uint32_t src1_tile_bytes = get_tile_size(cb_id_in1); | ||
DataFormat src1_data_format = get_dataformat(cb_id_in1); | ||
const InterleavedAddrGenFast<src1_is_dram> s1 = { | ||
.bank_base_address = src1_addr, .page_size = src1_tile_bytes, .data_format = src1_data_format}; | ||
|
||
constexpr uint32_t onetile = 1; | ||
for (uint32_t i = start_id; i < start_id + num_tiles; i++) { | ||
bool last_tile = i == (start_id + num_tiles - 1); | ||
cb_reserve_back(cb_id_in0, onetile); | ||
l1_write_addr_in0 = get_write_ptr(cb_id_in0); | ||
noc_async_read_tile(i, s0, l1_write_addr_in0); | ||
|
||
cb_reserve_back(cb_id_in1, onetile); | ||
l1_write_addr_in1 = get_write_ptr(cb_id_in1); | ||
noc_async_read_tile(i, s1, l1_write_addr_in1); | ||
|
||
noc_async_read_barrier(); | ||
|
||
if (last_tile) { | ||
mask_tile_in_reader(l1_write_addr_in0, mask_w, mask_h); | ||
mask_tile_in_reader(l1_write_addr_in1, mask_w, mask_h); | ||
} | ||
|
||
cb_push_back(cb_id_in0, onetile); | ||
cb_push_back(cb_id_in1, onetile); | ||
} | ||
} |
31 changes: 31 additions & 0 deletions
31
tt_eager/tt_dnn/op_library/prod/kernels/writer_prod_hw.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
#include "dataflow_api.h" | ||
|
||
void kernel_main() { | ||
uint32_t dst_addr = get_arg_val<uint32_t>(0); | ||
uint32_t num_tiles = get_arg_val<uint32_t>(1); | ||
uint32_t start_id = get_arg_val<uint32_t>(2); | ||
|
||
constexpr uint32_t cb_id_out = get_compile_time_arg_val(0); | ||
constexpr bool dst_is_dram = get_compile_time_arg_val(1) == 1; | ||
|
||
// single-tile ublocks | ||
constexpr uint32_t onetile = 1; | ||
const uint32_t tile_bytes = get_tile_size(cb_id_out); | ||
const DataFormat data_format = get_dataformat(cb_id_out); | ||
|
||
const InterleavedAddrGenFast<dst_is_dram> s = { | ||
.bank_base_address = dst_addr, .page_size = tile_bytes, .data_format = data_format}; | ||
|
||
uint32_t end_id = start_id + num_tiles; | ||
for (uint32_t i = start_id; i < end_id; i++) { | ||
cb_wait_front(cb_id_out, onetile); | ||
uint32_t l1_read_addr = get_read_ptr(cb_id_out); | ||
noc_async_write_tile(i, s, l1_read_addr); | ||
noc_async_write_barrier(); | ||
cb_pop_front(cb_id_out, onetile); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,144 @@ | ||
// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
#include <cstring>

#include "tt_dnn/op_library/moreh_dot/moreh_dot_op.hpp"
#include "tt_eager/tt_dnn/op_library/moreh_helper_functions.hpp"
#include "tt_metal/common/constants.hpp"
#include "tt_metal/detail/util.hpp"
#include "tt_metal/host_api.hpp"
|
||
using namespace tt::constants; | ||
|
||
namespace tt { | ||
|
||
namespace operations { | ||
|
||
namespace primary { | ||
|
||
// Builds a single-core program (core {0, 0}) from the prod_hw reader, compute
// and writer kernels.
//
// a, b:    input tensors; their buffers feed the reader kernel. The tile count
//          and the last-tile mask are both derived from `a` only.
// output:  tensor whose buffer receives the single tile written by the writer.
//
// Returns the program together with a callback that refreshes buffer addresses
// and tile counts in the runtime arguments. The callback does not touch the
// reader's start id or mask arguments.
operation::ProgramWithCallbacks moreh_dot_single_core(const Tensor &a, const Tensor &b, Tensor &output) {
    Program program{};
    CoreCoord core = {0, 0};
    const uint32_t core_num = 1;

    DataFormat cb_data_format = datatype_to_dataformat_converter(output.dtype());
    uint32_t single_tile_size = detail::TileSize(cb_data_format);

    tt_metal::Buffer *src0_buffer = a.buffer();
    tt_metal::Buffer *src1_buffer = b.buffer();

    uint32_t num_tiles = a.volume() / TILE_HW;

    // The scaler reaches the reader kernel as the raw bit pattern of a float.
    // Copy the bytes with memcpy: reading a float object through a uint32_t
    // lvalue breaks the strict-aliasing rule.
    const float scaler = 1.0f;
    uint32_t scaler_bits = 0;
    static_assert(sizeof(scaler_bits) == sizeof(scaler), "scaler must be a 32-bit float");
    std::memcpy(&scaler_bits, &scaler, sizeof(scaler_bits));

    // Number of valid rows/columns in a tile, taken from the unpadded shape of
    // `a`. A dimension that is an exact multiple of the tile size keeps the
    // whole tile.
    const auto &a_shape_wo_padding = a.shape().without_padding();
    uint32_t pad_h = a_shape_wo_padding[2] % TILE_HEIGHT;
    uint32_t pad_w = a_shape_wo_padding[3] % TILE_WIDTH;
    uint32_t mask_h = (pad_h == 0) ? (TILE_HEIGHT) : (pad_h);
    uint32_t mask_w = (pad_w == 0) ? (TILE_WIDTH) : (pad_w);

    // The output buffer must already exist; nothing is allocated here.
    tt_metal::Device *device = a.device();
    tt_metal::Buffer *dst_buffer = output.buffer();
    TT_ASSERT(dst_buffer != nullptr, "Output buffer should be allocated on device!");

    ////////////////////////////////////////////////////////////////////////////
    //                         CircularBuffer Setup
    ////////////////////////////////////////////////////////////////////////////
    const uint32_t in0_t = 2;   // a
    const uint32_t in1_t = 2;   // b
    const uint32_t in2_t = 1;   // scaler
    const uint32_t out0_t = 2;  // out
    const uint32_t im0_t = 1;
    const uint32_t im1_t = 1;

    CreateCircularBuffer(
        program,
        std::set<CoreRange>{CoreRange{.start = core, .end = core}},
        cb_data_format,
        {
            {CB::c_in0, in0_t},
            {CB::c_in1, in1_t},
            {CB::c_in2, in2_t},
            {CB::c_out0, out0_t},
            {CB::c_intermed0, im0_t},
            {CB::c_intermed1, im1_t},
        });

    ////////////////////////////////////////////////////////////////////////////
    //                      DataMovementKernel SetUp
    ////////////////////////////////////////////////////////////////////////////
    std::vector<uint32_t> reader_compile_time_args = {
        (std::uint32_t)is_dram(src0_buffer), (std::uint32_t)is_dram(src1_buffer), scaler_bits};

    std::vector<uint32_t> writer_compile_time_args = {(std::uint32_t)CB::c_out0, (std::uint32_t)is_dram(dst_buffer)};

    const auto reader_kernel_file = "tt_eager/tt_dnn/op_library/prod/kernels/reader_prod_hw.cpp";
    const auto writer_kernel_file = "tt_eager/tt_dnn/op_library/prod/kernels/writer_prod_hw.cpp";

    const auto reader_kernel_id = CreateReadKernel(program, reader_kernel_file, core, reader_compile_time_args);
    const auto writer_kernel_id = CreateWriteKernel(program, writer_kernel_file, core, writer_compile_time_args);

    ////////////////////////////////////////////////////////////////////////////
    //                      ComputeKernel SetUp
    ////////////////////////////////////////////////////////////////////////////
    vector<uint32_t> compute_kernel_args = {};
    std::map<string, string> compute_defines;

    const auto compute_kernel_file = "tt_eager/tt_dnn/op_library/prod/kernels/prod_hw.cpp";
    const auto compute_kernel_id =
        CreateComputeKernel(program, compute_kernel_file, {core, core_num, compute_kernel_args}, compute_defines);

    ////////////////////////////////////////////////////////////////////////////
    //                      RuntimeArgs SetUp
    ////////////////////////////////////////////////////////////////////////////
    // Reader: src0 address, src1 address, tile count, start id, mask_h, mask_w.
    SetRuntimeArgs(
        program,
        reader_kernel_id,
        core,
        {src0_buffer->address(), src1_buffer->address(), num_tiles, 0, mask_h, mask_w});
    SetRuntimeArgs(program, compute_kernel_id, core, {num_tiles, 1});
    // Writer: destination address, tile count (a single tile), start id.
    SetRuntimeArgs(program, writer_kernel_id, core, {dst_buffer->address(), 1, 0});

    auto override_runtime_arguments_callback = [reader_kernel_id, writer_kernel_id, compute_kernel_id](
                                                   const void *operation,
                                                   const Program &program,
                                                   const std::vector<Tensor> &input_tensors,
                                                   const std::vector<std::optional<const Tensor>> &,
                                                   const std::vector<Tensor> &output_tensors) {
        auto src_buffer_a = input_tensors.at(0).buffer();
        auto src_buffer_b = input_tensors.at(1).buffer();

        auto dst_buffer = output_tensors.at(0).buffer();

        CoreCoord core = {0, 0};

        uint32_t num_tiles = input_tensors.at(0).volume() / TILE_HW;

        {
            auto runtime_args = GetRuntimeArgs(program, reader_kernel_id, core);
            runtime_args[0] = src_buffer_a->address();
            runtime_args[1] = src_buffer_b->address();
            runtime_args[2] = num_tiles;
            SetRuntimeArgs(program, reader_kernel_id, core, runtime_args);
        }

        {
            auto runtime_args = GetRuntimeArgs(program, compute_kernel_id, core);
            runtime_args[0] = num_tiles;
            SetRuntimeArgs(program, compute_kernel_id, core, runtime_args);
        }

        {
            auto runtime_args = GetRuntimeArgs(program, writer_kernel_id, core);
            runtime_args[0] = dst_buffer->address();
            runtime_args[1] = 1;
            SetRuntimeArgs(program, writer_kernel_id, core, runtime_args);
        }
    };
    return {.program = std::move(program), .override_runtime_arguments_callback = override_runtime_arguments_callback};
}
|
||
} // namespace primary | ||
} // namespace operations | ||
} // namespace tt |
Oops, something went wrong.