Skip to content

Commit

Permalink
#3900: Add prod support for batch and channels
Browse files Browse the repository at this point in the history
  • Loading branch information
Muthu authored and ruthreshx committed Feb 2, 2024
1 parent d43c1a5 commit 3da2e49
Show file tree
Hide file tree
Showing 7 changed files with 375 additions and 6 deletions.
11 changes: 10 additions & 1 deletion tests/tt_eager/python_api_testing/unit_testing/test_prod.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,12 +53,21 @@ def get_tensors(input_shape, output_shape, device):
[
1,
],
[
2,
],
[
3,
],
),
ids=["0", "1"],
ids=["0", "1", "2", "3"],
)
def test_moreh_prod_dims(input_shape, dims, device):
output_shape = input_shape.copy()

if dims[0] in [2, 3]:
pytest.skip(f"Dim {dims[0]} not supported at this time.")

for dim in dims:
output_shape[dim] = 1

Expand Down
62 changes: 62 additions & 0 deletions tt_eager/tt_dnn/op_library/prod/kernels/prod_hw.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#include <cstdint>

#include "compute_kernel_api/eltwise_binary.h"
#include "compute_kernel_api/reduce.h"
#include "compute_kernel_api/tile_move_copy.h"

// Helpers that bracket each compute step: acquire/release half of the DST
// register file (standard tt-metal double-buffering of the packer/math DST).
ALWI void ACQ() { acquire_dst(tt::DstMode::Half); }
ALWI void REL() { release_dst(tt::DstMode::Half); }

namespace NAMESPACE {
// Product-reduction compute kernel: for each incoming tile pair, compute
// c_in0 * c_in1 elementwise, carry a running partial through c_intermed1,
// and emit the final tile to c_out0 on the last iteration.
void MAIN {
    constexpr int onetile = 1;
    // Number of tile pairs this core processes (runtime arg 0).
    uint32_t per_core_block_cnt = get_arg_val<uint32_t>(0);
    binary_op_init_common(tt::CB::c_in0, tt::CB::c_in1);
    // False only on the first iteration, when there is no partial to reload.
    bool enable_reload = false;
    for (uint32_t block = 0; block < per_core_block_cnt; ++block) {
        bool last_out = block == (per_core_block_cnt - 1);

        // elemwise-mul
        ACQ();
        cb_wait_front(tt::CB::c_in0, onetile);
        cb_wait_front(tt::CB::c_in1, onetile);

        cb_reserve_back(tt::CB::c_intermed0, onetile);
        mul_tiles_init();
        // dst0 = c_in0 x c_in1
        mul_tiles(tt::CB::c_in0, tt::CB::c_in1, 0, 0, 0);
        // c_intermed0 = pack(dst0)
        pack_tile(0, tt::CB::c_intermed0);
        cb_push_back(tt::CB::c_intermed0, onetile);

        cb_pop_front(tt::CB::c_in0, onetile);
        cb_pop_front(tt::CB::c_in1, onetile);
        REL();

        // Accumulation stage: reload the previous partial (if any) into dst0,
        // then pack dst0 either to the output CB (last iteration) or back to
        // c_intermed1 to be carried into the next iteration.
        // NOTE(review): c_intermed0 is pushed above but never popped in this
        // kernel, and nothing here visibly combines the fresh mul result with
        // the reloaded c_intermed1 before packing — confirm against the host
        // program / unpacker configuration that the accumulation is wired up.
        ACQ();
        if (enable_reload) {
            cb_wait_front(tt::CB::c_intermed1, onetile);
            copy_tile_to_dst_init_short();
            // dst0 = previous partial product
            copy_tile(tt::CB::c_intermed1, 0, 0);
            cb_pop_front(tt::CB::c_intermed1, onetile);
        }

        if (last_out) {
            cb_reserve_back(tt::CB::c_out0, onetile);
            pack_tile(0, tt::CB::c_out0);
            cb_push_back(tt::CB::c_out0, onetile);
        } else {
            cb_reserve_back(tt::CB::c_intermed1, onetile);
            pack_tile(0, tt::CB::c_intermed1);
            cb_push_back(tt::CB::c_intermed1, onetile);
        }
        REL();
        enable_reload = true;
    }
}
}  // namespace NAMESPACE
114 changes: 114 additions & 0 deletions tt_eager/tt_dnn/op_library/prod/kernels/reader_prod_hw.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#include <stdint.h>

#include "dataflow_api.h"

// Zeroes the padded region of one 32x32 bfloat16 tile resident in L1.
// The tile is stored as four 16x16 faces laid out consecutively
// (top-left, top-right, bottom-left, bottom-right; 256 elements each).
// mask_w / mask_h give the count of valid columns / rows; every element at
// or beyond them is overwritten with bfloat16 zero.
//
// Improvements over the previous version: the per-face valid extents are
// hoisted out of the row loop (they are loop-invariant), and the
// float/uint32 union — which always produced the constant 0 for a 0.0f
// fill — is replaced by a named constant.
void mask_tile_in_reader(uint32_t l1_addr, uint32_t mask_w = 32, uint32_t mask_h = 32) {
    constexpr uint32_t kFaceDim = 16;                   // a face is 16x16 elements
    constexpr uint32_t kFaceElems = kFaceDim * kFaceDim;
    // bfloat16(0.0f) is all-zero bits, so the fill value is simply 0.
    constexpr uint16_t kZero = 0;

    auto tile = reinterpret_cast<uint16_t *>(l1_addr);

    // Valid extents per face column (0: tile cols 0-15, 1: tile cols 16-31)
    // and per face row (0: tile rows 0-15, 1: tile rows 16-31).
    const uint32_t valid_w[2] = {
        (mask_w >= kFaceDim) ? kFaceDim : mask_w,
        (mask_w < kFaceDim) ? 0 : mask_w - kFaceDim};
    const uint32_t valid_h[2] = {
        (mask_h >= kFaceDim) ? kFaceDim : mask_h,
        (mask_h < kFaceDim) ? 0 : mask_h - kFaceDim};

    for (uint32_t face = 0; face < 4; face++) {
        const uint32_t face_valid_w = valid_w[face & 1];   // faces 0,2 left; 1,3 right
        const uint32_t face_valid_h = valid_h[face >> 1];  // faces 0,1 top; 2,3 bottom
        uint16_t *face_base = tile + face * kFaceElems;
        for (uint32_t h = 0; h < kFaceDim; h++) {
            // Rows past the valid height are cleared in full (w starts at 0);
            // valid rows are cleared only past the valid width.
            for (uint32_t w = (h >= face_valid_h) ? 0 : face_valid_w; w < kFaceDim; w++) {
                face_base[h * kFaceDim + w] = kZero;
            }
        }
    }
}

// Reader kernel: streams matching tiles of two interleaved DRAM tensors into
// CBs 0 and 1, publishes an optional scaler tile into CB 2, and masks the
// padding region of the final tile of each input.
void kernel_main() {
    // same arg indices as in reader_binary_diff_lengths for compat
    uint32_t src0_addr = get_arg_val<uint32_t>(0);  // DRAM base of input a
    uint32_t src1_addr = get_arg_val<uint32_t>(1);  // DRAM base of input b
    uint32_t num_tiles = get_arg_val<uint32_t>(2);
    uint32_t start_id = get_arg_val<uint32_t>(3);   // first tile index to read
    uint32_t mask_h = get_arg_val<uint32_t>(4);     // valid rows in the last tile
    uint32_t mask_w = get_arg_val<uint32_t>(5);     // valid cols in the last tile

    constexpr bool src0_is_dram = get_compile_time_arg_val(0) == 1;
    constexpr bool src1_is_dram = get_compile_time_arg_val(1) == 1;
    // Raw IEEE-754 float bits of the scaler; 0 means "leave the tile as-is".
    constexpr uint32_t scaler = get_compile_time_arg_val(2);

    constexpr uint32_t cb_id_in0 = 0;
    constexpr uint32_t cb_id_in1 = 1;
    constexpr uint32_t cb_id_in2 = 2;  // scaler tile
    // Publish the scaler tile once, before the streaming loop.
    cb_reserve_back(cb_id_in2, 1);
    if (scaler != 0) {
        auto ptr = reinterpret_cast<uint16_t *>(get_write_ptr(cb_id_in2));
        // Clear the whole 32x32 tile (1024 bfloat16 elements).
        for (int j = 0; j < 1024; j++) ptr[j] = uint16_t(0);

        // Write bfloat16(scaler) into the first row of each of the four
        // 16x16 faces — presumably the layout the reduce primitive expects
        // for its scaler operand; confirm against the compute kernel.
        for (int k = 0; k < 4; k++)
            for (int j = 0; j < 16; j++) ptr[k * 256 + j] = uint16_t(scaler >> 16);
    }
    cb_push_back(cb_id_in2, 1);

    uint32_t l1_write_addr_in0;
    uint32_t src0_tile_bytes = get_tile_size(cb_id_in0);
    DataFormat src0_data_format = get_dataformat(cb_id_in0);
    const InterleavedAddrGenFast<src0_is_dram> s0 = {
        .bank_base_address = src0_addr, .page_size = src0_tile_bytes, .data_format = src0_data_format};
    uint32_t l1_write_addr_in1;
    uint32_t src1_tile_bytes = get_tile_size(cb_id_in1);
    DataFormat src1_data_format = get_dataformat(cb_id_in1);
    const InterleavedAddrGenFast<src1_is_dram> s1 = {
        .bank_base_address = src1_addr, .page_size = src1_tile_bytes, .data_format = src1_data_format};

    constexpr uint32_t onetile = 1;
    // Stream both inputs tile-by-tile; only the final tile may carry padding.
    for (uint32_t i = start_id; i < start_id + num_tiles; i++) {
        bool last_tile = i == (start_id + num_tiles - 1);
        cb_reserve_back(cb_id_in0, onetile);
        l1_write_addr_in0 = get_write_ptr(cb_id_in0);
        noc_async_read_tile(i, s0, l1_write_addr_in0);

        cb_reserve_back(cb_id_in1, onetile);
        l1_write_addr_in1 = get_write_ptr(cb_id_in1);
        noc_async_read_tile(i, s1, l1_write_addr_in1);

        // Both async reads must land before the tiles are masked or pushed.
        noc_async_read_barrier();

        if (last_tile) {
            // Zero out the padded rows/columns of the last tile.
            // NOTE(review): for a *product* reduction a zero pad would
            // annihilate the result — presumably something downstream makes
            // the pad neutral (or inputs are exact multiples of a tile);
            // confirm against the compute kernel and callers.
            mask_tile_in_reader(l1_write_addr_in0, mask_w, mask_h);
            mask_tile_in_reader(l1_write_addr_in1, mask_w, mask_h);
        }

        cb_push_back(cb_id_in0, onetile);
        cb_push_back(cb_id_in1, onetile);
    }
}
31 changes: 31 additions & 0 deletions tt_eager/tt_dnn/op_library/prod/kernels/writer_prod_hw.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#include "dataflow_api.h"

void kernel_main() {
uint32_t dst_addr = get_arg_val<uint32_t>(0);
uint32_t num_tiles = get_arg_val<uint32_t>(1);
uint32_t start_id = get_arg_val<uint32_t>(2);

constexpr uint32_t cb_id_out = get_compile_time_arg_val(0);
constexpr bool dst_is_dram = get_compile_time_arg_val(1) == 1;

// single-tile ublocks
constexpr uint32_t onetile = 1;
const uint32_t tile_bytes = get_tile_size(cb_id_out);
const DataFormat data_format = get_dataformat(cb_id_out);

const InterleavedAddrGenFast<dst_is_dram> s = {
.bank_base_address = dst_addr, .page_size = tile_bytes, .data_format = data_format};

uint32_t end_id = start_id + num_tiles;
for (uint32_t i = start_id; i < end_id; i++) {
cb_wait_front(cb_id_out, onetile);
uint32_t l1_read_addr = get_read_ptr(cb_id_out);
noc_async_write_tile(i, s, l1_read_addr);
noc_async_write_barrier();
cb_pop_front(cb_id_out, onetile);
}
}
144 changes: 144 additions & 0 deletions tt_eager/tt_dnn/op_library/prod/prod_nc/prod_hw.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#include "tt_dnn/op_library/moreh_dot/moreh_dot_op.hpp"
#include "tt_eager/tt_dnn/op_library/moreh_helper_functions.hpp"
#include "tt_metal/common/constants.hpp"
#include "tt_metal/detail/util.hpp"
#include "tt_metal/host_api.hpp"

using namespace tt::constants;

namespace tt {

namespace operations {

namespace primary {

// Builds the single-core program that combines every tile of `a` with the
// matching tile of `b` (elementwise multiply, then accumulation across tiles)
// and writes a single output tile.
// NOTE(review): the function is named moreh_dot_single_core but lives in the
// prod op library — presumably copied from moreh_dot; consider a prod-specific
// name in a follow-up (the current symbol must stay for existing callers).
operation::ProgramWithCallbacks moreh_dot_single_core(const Tensor &a, const Tensor &b, Tensor &output) {
    Program program{};
    CoreCoord core = {0, 0};
    const uint32_t core_num = 1;

    DataFormat cb_data_format = datatype_to_dataformat_converter(output.dtype());
    // NOTE(review): single_tile_size appears unused below — confirm and remove.
    uint32_t single_tile_size = detail::TileSize(cb_data_format);

    tt_metal::Buffer *src0_buffer = a.buffer();
    tt_metal::Buffer *src1_buffer = b.buffer();

    // Total number of 32x32 tiles in input `a` (b is assumed to match).
    uint32_t num_tiles = a.volume() / TILE_HW;
    float scaler = 1.0f;  // passed to the reader as raw float bits; 1.0 = identity
    const auto &a_shape_wo_padding = a.shape().without_padding();
    // Valid (unpadded) extent of the last tile row/column; a full tile when
    // the unpadded dim is an exact multiple of the tile size.
    uint32_t pad_h = a_shape_wo_padding[2] % TILE_HEIGHT;
    uint32_t pad_w = a_shape_wo_padding[3] % TILE_WIDTH;
    uint32_t mask_h = (pad_h == 0) ? (TILE_HEIGHT) : (pad_h);
    uint32_t mask_w = (pad_w == 0) ? (TILE_WIDTH) : (pad_w);

    // This should allocate a DRAM buffer on the device
    tt_metal::Device *device = a.device();
    tt_metal::Buffer *dst_buffer = output.buffer();
    TT_ASSERT(dst_buffer != nullptr, "Output buffer should be allocated on device!");

    ////////////////////////////////////////////////////////////////////////////
    //                         CircularBuffer Setup
    ////////////////////////////////////////////////////////////////////////////
    // Double-buffer the streamed inputs and output; single-buffer the scaler
    // and the intermediates used for the running partial.
    const uint32_t in0_t = 2;   // a
    const uint32_t in1_t = 2;   // b
    const uint32_t in2_t = 1;   // scaler
    const uint32_t out0_t = 2;  // out
    const uint32_t im0_t = 1;
    const uint32_t im1_t = 1;

    CreateCircularBuffer(
        program,
        std::set<CoreRange>{CoreRange{.start = core, .end = core}},
        cb_data_format,
        {
            {CB::c_in0, in0_t},
            {CB::c_in1, in1_t},
            {CB::c_in2, in2_t},
            {CB::c_out0, out0_t},
            {CB::c_intermed0, im0_t},
            {CB::c_intermed1, im1_t},
        });

    ////////////////////////////////////////////////////////////////////////////
    //                      DataMovementKernel SetUp
    ////////////////////////////////////////////////////////////////////////////
    // Reader compile-time args: DRAM flags for both inputs, plus the scaler
    // as raw float bits (type-punned; the reader reconstructs bfloat16 from it).
    std::vector<uint32_t> reader_compile_time_args = {
        (std::uint32_t)is_dram(src0_buffer),
        (std::uint32_t)is_dram(src1_buffer),
        *reinterpret_cast<uint32_t *>(&scaler)};

    std::vector<uint32_t> writer_compile_time_args = {(std::uint32_t)CB::c_out0, (std::uint32_t)is_dram(dst_buffer)};

    const auto reader_kernel_file = "tt_eager/tt_dnn/op_library/prod/kernels/reader_prod_hw.cpp";
    const auto writer_kernel_file = "tt_eager/tt_dnn/op_library/prod/kernels/writer_prod_hw.cpp";

    const auto reader_kernel_id = CreateReadKernel(program, reader_kernel_file, core, reader_compile_time_args);
    const auto writer_kernel_id = CreateWriteKernel(program, writer_kernel_file, core, writer_compile_time_args);

    ////////////////////////////////////////////////////////////////////////////
    //                      ComputeKernel SetUp
    ////////////////////////////////////////////////////////////////////////////
    vector<uint32_t> compute_kernel_args = {};
    std::map<string, string> compute_defines;

    const auto compute_kernel_file = "tt_eager/tt_dnn/op_library/prod/kernels/prod_hw.cpp";
    const auto compute_kernel_id =
        CreateComputeKernel(program, compute_kernel_file, {core, core_num, compute_kernel_args}, compute_defines);

    ////////////////////////////////////////////////////////////////////////////
    //                      RuntimeArgs SetUp
    ////////////////////////////////////////////////////////////////////////////
    // Reader args: {src0, src1, num_tiles, start_id=0, mask_h, mask_w}.
    SetRuntimeArgs(
        program,
        reader_kernel_id,
        core,
        {src0_buffer->address(), src1_buffer->address(), num_tiles, 0, mask_h, mask_w});
    // Compute arg 0 = tile-pair count (arg 1 is unused by the kernel).
    SetRuntimeArgs(program, compute_kernel_id, core, {num_tiles, 1});
    // Writer args: {dst, num_tiles=1, start_id=0} — the reduction emits one tile.
    SetRuntimeArgs(program, writer_kernel_id, core, {output.buffer()->address(), 1, 0});

    // Refreshes buffer addresses and tile counts when the op is re-run with
    // different tensors on a cached program.
    // NOTE(review): mask_h/mask_w (reader args 4/5) are NOT refreshed here —
    // confirm the program cache key guarantees the unpadded shape cannot
    // change between runs.
    auto override_runtime_arguments_callback = [reader_kernel_id, writer_kernel_id, compute_kernel_id](
                                                   const void *operation,
                                                   const Program &program,
                                                   const std::vector<Tensor> &input_tensors,
                                                   const std::vector<std::optional<const Tensor>> &,
                                                   const std::vector<Tensor> &output_tensors) {
        auto src_buffer_a = input_tensors.at(0).buffer();
        auto src_buffer_b = input_tensors.at(1).buffer();

        auto dst_buffer = output_tensors.at(0).buffer();

        CoreCoord core = {0, 0};

        uint32_t num_tiles = input_tensors.at(0).volume() / TILE_HW;

        {
            auto runtime_args = GetRuntimeArgs(program, reader_kernel_id, core);
            runtime_args[0] = src_buffer_a->address();
            runtime_args[1] = src_buffer_b->address();
            runtime_args[2] = num_tiles;
            SetRuntimeArgs(program, reader_kernel_id, core, runtime_args);
        }

        {
            auto runtime_args = GetRuntimeArgs(program, compute_kernel_id, core);
            runtime_args[0] = num_tiles;
            SetRuntimeArgs(program, compute_kernel_id, core, runtime_args);
        }

        {
            auto runtime_args = GetRuntimeArgs(program, writer_kernel_id, core);
            runtime_args[0] = dst_buffer->address();
            runtime_args[1] = 1;
            SetRuntimeArgs(program, writer_kernel_id, core, runtime_args);
        }
    };
    return {.program = std::move(program), .override_runtime_arguments_callback = override_runtime_arguments_callback};
}

} // namespace primary
} // namespace operations
} // namespace tt
Loading

0 comments on commit 3da2e49

Please sign in to comment.