diff --git a/tests/tt_eager/python_api_testing/unit_testing/test_prod.py b/tests/tt_eager/python_api_testing/unit_testing/test_prod.py
index 4be0522e832..bdd368ae0e5 100644
--- a/tests/tt_eager/python_api_testing/unit_testing/test_prod.py
+++ b/tests/tt_eager/python_api_testing/unit_testing/test_prod.py
@@ -53,12 +53,21 @@ def get_tensors(input_shape, output_shape, device):
         [
             1,
         ],
+        [
+            2,
+        ],
+        [
+            3,
+        ],
     ),
-    ids=["0", "1"],
+    ids=["0", "1", "2", "3"],
 )
 def test_moreh_prod_dims(input_shape, dims, device):
     output_shape = input_shape.copy()
+    if dims[0] in [2, 3]:
+        pytest.skip(f"Dim {dims[0]} not supported at this time.")
+
     for dim in dims:
         output_shape[dim] = 1
diff --git a/tt_eager/tt_dnn/op_library/prod/kernels/prod_hw.cpp b/tt_eager/tt_dnn/op_library/prod/kernels/prod_hw.cpp
new file mode 100644
index 00000000000..5e94c03117d
--- /dev/null
+++ b/tt_eager/tt_dnn/op_library/prod/kernels/prod_hw.cpp
@@ -0,0 +1,62 @@
+// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
+//
+// SPDX-License-Identifier: Apache-2.0
+
+#include <cstdint>
+
+#include "compute_kernel_api/eltwise_binary.h"
+#include "compute_kernel_api/reduce.h"
+#include "compute_kernel_api/tile_move_copy.h"
+
+ALWI void ACQ() { acquire_dst(tt::DstMode::Half); }
+ALWI void REL() { release_dst(tt::DstMode::Half); }
+
+namespace NAMESPACE {
+void MAIN {
+    constexpr int onetile = 1;
+    uint32_t per_core_block_cnt = get_arg_val<uint32_t>(0);
+    binary_op_init_common(tt::CB::c_in0, tt::CB::c_in1);
+    bool enable_reload = false;
+    for (uint32_t block = 0; block < per_core_block_cnt; ++block) {
+        bool last_out = block == (per_core_block_cnt - 1);
+
+        // elemwise-mul
+        ACQ();
+        cb_wait_front(tt::CB::c_in0, onetile);
+        cb_wait_front(tt::CB::c_in1, onetile);
+
+        cb_reserve_back(tt::CB::c_intermed0, onetile);
+        mul_tiles_init();
+        // dst0 = c_in0 x c_in1
+        mul_tiles(tt::CB::c_in0, tt::CB::c_in1, 0, 0, 0);
+        // c_intermed0 = pack(dst0)
+        pack_tile(0, tt::CB::c_intermed0);
+        cb_push_back(tt::CB::c_intermed0, onetile);
+
+        cb_pop_front(tt::CB::c_in0, onetile);
+        cb_pop_front(tt::CB::c_in1, onetile);
+        REL();
+
+        // reduce-w
+        ACQ();
+        if (enable_reload) {
+            cb_wait_front(tt::CB::c_intermed1, onetile);
+            copy_tile_to_dst_init_short();
+            copy_tile(tt::CB::c_intermed1, 0, 0);
+            cb_pop_front(tt::CB::c_intermed1, onetile);
+        }
+
+        if (last_out) {
+            cb_reserve_back(tt::CB::c_out0, onetile);
+            pack_tile(0, tt::CB::c_out0);
+            cb_push_back(tt::CB::c_out0, onetile);
+        } else {
+            cb_reserve_back(tt::CB::c_intermed1, onetile);
+            pack_tile(0, tt::CB::c_intermed1);
+            cb_push_back(tt::CB::c_intermed1, onetile);
+        }
+        REL();
+        enable_reload = true;
+    }
+}
+}  // namespace NAMESPACE
diff --git a/tt_eager/tt_dnn/op_library/prod/kernels/reader_prod_hw.cpp b/tt_eager/tt_dnn/op_library/prod/kernels/reader_prod_hw.cpp
new file mode 100644
index 00000000000..5e0ca29daad
--- /dev/null
+++ b/tt_eager/tt_dnn/op_library/prod/kernels/reader_prod_hw.cpp
@@ -0,0 +1,114 @@
+// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
+//
+// SPDX-License-Identifier: Apache-2.0
+
+#include <stdint.h>
+
+#include "dataflow_api.h"
+
+void mask_tile_in_reader(uint32_t l1_addr, uint32_t mask_w = 32, uint32_t mask_h = 32) {
+    union {
+        float f;
+        uint32_t u;
+    } zero;
+    zero.f = 0.0f;
+    auto ptr = reinterpret_cast<uint16_t *>(l1_addr);
+    for (uint32_t h = 0; h < 16; h++) {
+        // sub tile 0
+        {
+            uint32_t mask_w_0 = (mask_w >= 16) ? 16 : mask_w;
+            uint32_t mask_h_0 = (mask_h >= 16) ? 16 : mask_h;
+            uint32_t w = (h >= mask_h_0) ? 0 : mask_w_0;
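+            // zero columns w..15 of this face row (the whole row when h is past the valid height)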
+            for (; w < 16; w++) {
+                ptr[h * 16 + w] = uint16_t(zero.u >> 16);
+            }
+        }
+        // sub tile 1
+        {
+            uint32_t mask_w_1 = (mask_w < 16) ? 0 : mask_w - 16;
+            uint32_t mask_h_0 = (mask_h >= 16) ? 16 : mask_h;
+            uint32_t w = (h >= mask_h_0) ? 0 : mask_w_1;
+            for (; w < 16; w++) {
+                ptr[h * 16 + w + 256] = uint16_t(zero.u >> 16);
+            }
+        }
+        // sub tile 2
+        {
+            uint32_t mask_w_0 = (mask_w >= 16) ? 16 : mask_w;
+            uint32_t mask_h_1 = (mask_h < 16) ? 0 : mask_h - 16;
+            uint32_t w = (h >= mask_h_1) ? 0 : mask_w_0;
+            for (; w < 16; w++) {
+                ptr[h * 16 + w + 512] = uint16_t(zero.u >> 16);
+            }
+        }
+        // sub tile 3
+        {
+            uint32_t mask_w_1 = (mask_w < 16) ? 0 : mask_w - 16;
+            uint32_t mask_h_1 = (mask_h < 16) ? 0 : mask_h - 16;
+            uint32_t w = (h >= mask_h_1) ? 0 : mask_w_1;
+            for (; w < 16; w++) {
+                ptr[h * 16 + w + 768] = uint16_t(zero.u >> 16);
+            }
+        }
+    }
+}
+
+void kernel_main() {
+    // same arg indices as in reader_binary_diff_lenghts for compat
+    uint32_t src0_addr = get_arg_val<uint32_t>(0);
+    uint32_t src1_addr = get_arg_val<uint32_t>(1);
+    uint32_t num_tiles = get_arg_val<uint32_t>(2);
+    uint32_t start_id = get_arg_val<uint32_t>(3);
+    uint32_t mask_h = get_arg_val<uint32_t>(4);
+    uint32_t mask_w = get_arg_val<uint32_t>(5);
+
+    constexpr bool src0_is_dram = get_compile_time_arg_val(0) == 1;
+    constexpr bool src1_is_dram = get_compile_time_arg_val(1) == 1;
+    constexpr uint32_t scaler = get_compile_time_arg_val(2);
+
+    constexpr uint32_t cb_id_in0 = 0;
+    constexpr uint32_t cb_id_in1 = 1;
+    constexpr uint32_t cb_id_in2 = 2;
+    cb_reserve_back(cb_id_in2, 1);
+    if (scaler != 0) {
+        auto ptr = reinterpret_cast<uint16_t *>(get_write_ptr(cb_id_in2));
+        for (int j = 0; j < 1024; j++) ptr[j] = uint16_t(0);
+
+        for (int k = 0; k < 4; k++)
+            for (int j = 0; j < 16; j++) ptr[k * 256 + j] = uint16_t(scaler >> 16);
+    }
+    cb_push_back(cb_id_in2, 1);
+
+    uint32_t l1_write_addr_in0;
+    uint32_t src0_tile_bytes = get_tile_size(cb_id_in0);
+    DataFormat src0_data_format = get_dataformat(cb_id_in0);
+    const InterleavedAddrGenFast<src0_is_dram> s0 = {
+        .bank_base_address = src0_addr, .page_size = src0_tile_bytes, .data_format = src0_data_format};
+    uint32_t l1_write_addr_in1;
+    uint32_t src1_tile_bytes = get_tile_size(cb_id_in1);
+    DataFormat src1_data_format = get_dataformat(cb_id_in1);
+    const InterleavedAddrGenFast<src1_is_dram> s1 = {
+        .bank_base_address = src1_addr, .page_size = src1_tile_bytes, .data_format = src1_data_format};
+
+    constexpr uint32_t onetile = 1;
+    for (uint32_t i = start_id; i < start_id + num_tiles; i++) {
+        bool last_tile = i == (start_id + num_tiles - 1);
+        cb_reserve_back(cb_id_in0, onetile);
+        l1_write_addr_in0 = get_write_ptr(cb_id_in0);
+        noc_async_read_tile(i, s0, l1_write_addr_in0);
+
+        cb_reserve_back(cb_id_in1, onetile);
+        l1_write_addr_in1 = get_write_ptr(cb_id_in1);
+        noc_async_read_tile(i, s1, l1_write_addr_in1);
+
+        noc_async_read_barrier();
+
+        if (last_tile) {
+            mask_tile_in_reader(l1_write_addr_in0, mask_w, mask_h);
+            mask_tile_in_reader(l1_write_addr_in1, mask_w, mask_h);
+        }
+
+        cb_push_back(cb_id_in0, onetile);
+        cb_push_back(cb_id_in1, onetile);
+    }
+}
diff --git a/tt_eager/tt_dnn/op_library/prod/kernels/writer_prod_hw.cpp b/tt_eager/tt_dnn/op_library/prod/kernels/writer_prod_hw.cpp
new file mode 100644
index 00000000000..5669308c229
--- /dev/null
+++ b/tt_eager/tt_dnn/op_library/prod/kernels/writer_prod_hw.cpp
@@ -0,0 +1,31 @@
+// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
+//
+// SPDX-License-Identifier: Apache-2.0
+
+#include "dataflow_api.h"
+
+void kernel_main() {
+    uint32_t dst_addr = get_arg_val<uint32_t>(0);
+    uint32_t num_tiles = get_arg_val<uint32_t>(1);
+    uint32_t start_id = get_arg_val<uint32_t>(2);
+
+    constexpr uint32_t cb_id_out = get_compile_time_arg_val(0);
+    constexpr bool dst_is_dram = get_compile_time_arg_val(1) == 1;
+
+    // single-tile ublocks
+    constexpr uint32_t onetile = 1;
+    const uint32_t tile_bytes = get_tile_size(cb_id_out);
+    const DataFormat data_format = get_dataformat(cb_id_out);
+
+    const InterleavedAddrGenFast<dst_is_dram> s = {
+        .bank_base_address = dst_addr, .page_size = tile_bytes, .data_format = data_format};
+
+    uint32_t end_id = start_id + num_tiles;
+    for (uint32_t i = start_id; i < end_id; i++) {
+        cb_wait_front(cb_id_out, onetile);
+        uint32_t l1_read_addr = get_read_ptr(cb_id_out);
+        noc_async_write_tile(i, s, l1_read_addr);
+        noc_async_write_barrier();
+        cb_pop_front(cb_id_out, onetile);
+    }
+}
diff --git a/tt_eager/tt_dnn/op_library/prod/prod_nc/prod_hw.cpp b/tt_eager/tt_dnn/op_library/prod/prod_nc/prod_hw.cpp
new file mode 100644
index 00000000000..c4253ec2377
--- /dev/null
+++ b/tt_eager/tt_dnn/op_library/prod/prod_nc/prod_hw.cpp
@@ -0,0 +1,144 @@
+// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
+//
+// SPDX-License-Identifier: Apache-2.0
+
+#include "tt_dnn/op_library/moreh_dot/moreh_dot_op.hpp"
+#include "tt_eager/tt_dnn/op_library/moreh_helper_functions.hpp"
+#include "tt_metal/common/constants.hpp"
+#include "tt_metal/detail/util.hpp"
+#include "tt_metal/host_api.hpp"
+
+using namespace tt::constants;
+
+namespace tt {
+
+namespace operations {
+
+namespace primary {
+
+operation::ProgramWithCallbacks moreh_dot_single_core(const Tensor &a, const Tensor &b, Tensor &output) {
+    Program program{};
+    CoreCoord core = {0, 0};
+    const uint32_t core_num = 1;
+
+    DataFormat cb_data_format = datatype_to_dataformat_converter(output.dtype());
+    uint32_t single_tile_size = detail::TileSize(cb_data_format);
+
+    tt_metal::Buffer *src0_buffer = a.buffer();
+    tt_metal::Buffer *src1_buffer = b.buffer();
+
+    uint32_t num_tiles = a.volume() / TILE_HW;
+    float scaler = 1.0f;
+    const auto &a_shape_wo_padding = a.shape().without_padding();
+    uint32_t pad_h = a_shape_wo_padding[2] % TILE_HEIGHT;
+    uint32_t pad_w = a_shape_wo_padding[3] % TILE_WIDTH;
+    uint32_t mask_h = (pad_h == 0) ? (TILE_HEIGHT) : (pad_h);
+    uint32_t mask_w = (pad_w == 0) ? (TILE_WIDTH) : (pad_w);
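+    // mask_h / mask_w hold the number of valid rows / columns in the last tile (a full 32 when unpadded)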
+
+    // This should allocate a DRAM buffer on the device
+    tt_metal::Device *device = a.device();
+    tt_metal::Buffer *dst_buffer = output.buffer();
+    TT_ASSERT(dst_buffer != nullptr, "Output buffer should be allocated on device!");
+
+    ////////////////////////////////////////////////////////////////////////////
+    //                         CircularBuffer Setup
+    ////////////////////////////////////////////////////////////////////////////
+    const uint32_t in0_t = 2;   // a
+    const uint32_t in1_t = 2;   // b
+    const uint32_t in2_t = 1;   // scaler
+    const uint32_t out0_t = 2;  // out
+    const uint32_t im0_t = 1;
+    const uint32_t im1_t = 1;
+
+    CreateCircularBuffer(
+        program,
+        std::set<CoreRange>{CoreRange{.start = core, .end = core}},
+        cb_data_format,
+        {
+            {CB::c_in0, in0_t},
+            {CB::c_in1, in1_t},
+            {CB::c_in2, in2_t},
+            {CB::c_out0, out0_t},
+            {CB::c_intermed0, im0_t},
+            {CB::c_intermed1, im1_t},
+        });
+
+    ////////////////////////////////////////////////////////////////////////////
+    //                      DataMovementKernel SetUp
+    ////////////////////////////////////////////////////////////////////////////
+    std::vector<uint32_t> reader_compile_time_args = {
+        (std::uint32_t)is_dram(src0_buffer),
+        (std::uint32_t)is_dram(src1_buffer),
+        *reinterpret_cast<uint32_t *>(&scaler)};
+
+    std::vector<uint32_t> writer_compile_time_args = {(std::uint32_t)CB::c_out0, (std::uint32_t)is_dram(dst_buffer)};
+
+    const auto reader_kernel_file = "tt_eager/tt_dnn/op_library/prod/kernels/reader_prod_hw.cpp";
+    const auto writer_kernel_file = "tt_eager/tt_dnn/op_library/prod/kernels/writer_prod_hw.cpp";
+
+    const auto reader_kernel_id = CreateReadKernel(program, reader_kernel_file, core, reader_compile_time_args);
+    const auto writer_kernel_id = CreateWriteKernel(program, writer_kernel_file, core, writer_compile_time_args);
+
+    ////////////////////////////////////////////////////////////////////////////
+    //                      ComputeKernel SetUp
+    ////////////////////////////////////////////////////////////////////////////
+    vector<uint32_t> compute_kernel_args = {};
+    std::map<string, string> compute_defines;
+
+    const auto compute_kernel_file = "tt_eager/tt_dnn/op_library/prod/kernels/prod_hw.cpp";
+    const auto compute_kernel_id =
+        CreateComputeKernel(program, compute_kernel_file, {core, core_num, compute_kernel_args}, compute_defines);
+
+    ////////////////////////////////////////////////////////////////////////////
+    //                      RuntimeArgs SetUp
+    ////////////////////////////////////////////////////////////////////////////
+    SetRuntimeArgs(
+        program,
+        reader_kernel_id,
+        core,
+        {src0_buffer->address(), src1_buffer->address(), num_tiles, 0, mask_h, mask_w});
+    SetRuntimeArgs(program, compute_kernel_id, core, {num_tiles, 1});
+    SetRuntimeArgs(program, writer_kernel_id, core, {output.buffer()->address(), 1, 0});
+
+    auto override_runtime_arguments_callback = [reader_kernel_id, writer_kernel_id, compute_kernel_id](
+                                                   const void *operation,
+                                                   const Program &program,
+                                                   const std::vector<Tensor> &input_tensors,
+                                                   const std::vector<std::optional<const Tensor>> &,
+                                                   const std::vector<Tensor> &output_tensors) {
+        auto src_buffer_a = input_tensors.at(0).buffer();
+        auto src_buffer_b = input_tensors.at(1).buffer();
+
+        auto dst_buffer = output_tensors.at(0).buffer();
+
+        CoreCoord core = {0, 0};
+
+        uint32_t num_tiles = input_tensors.at(0).volume() / TILE_HW;
+
+        {
+            auto runtime_args = GetRuntimeArgs(program, reader_kernel_id, core);
+            runtime_args[0] = src_buffer_a->address();
+            runtime_args[1] = src_buffer_b->address();
+            runtime_args[2] = num_tiles;
+            SetRuntimeArgs(program, reader_kernel_id, core, runtime_args);
+        }
+
+        {
+            auto runtime_args = GetRuntimeArgs(program, compute_kernel_id, core);
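+            // only the tile count changes between runs; the compute kernel's second arg (block size of 1) stays fixed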
+            runtime_args[0] = num_tiles;
+            SetRuntimeArgs(program, compute_kernel_id, core, runtime_args);
+        }
+
+        {
+            auto runtime_args = GetRuntimeArgs(program, writer_kernel_id, core);
+            runtime_args[0] = dst_buffer->address();
+            runtime_args[1] = 1;
+            SetRuntimeArgs(program, writer_kernel_id, core, runtime_args);
+        }
+    };
+    return {.program = std::move(program), .override_runtime_arguments_callback = override_runtime_arguments_callback};
+}
+
+}  // namespace primary
+}  // namespace operations
+}  // namespace tt
diff --git a/tt_eager/tt_dnn/op_library/prod/prod_nc_op.cpp b/tt_eager/tt_dnn/op_library/prod/prod_nc_op.cpp
index 8861313934f..679b72fdc50 100644
--- a/tt_eager/tt_dnn/op_library/prod/prod_nc_op.cpp
+++ b/tt_eager/tt_dnn/op_library/prod/prod_nc_op.cpp
@@ -22,6 +22,7 @@ void Prod::validate(const std::vector<Tensor>& inputs) const {
     const auto& output = inputs.at(1);
 
     auto input_shape = input.shape();
+    TT_ASSERT((input_shape.rank() == 4), "rank should be 4");
     const auto& output_shape = output.shape();
     auto input_shape_wo_padding = input.shape().without_padding();
     const auto& output_shape_wo_padding = output.shape().without_padding();
@@ -32,8 +33,8 @@ void Prod::validate(const std::vector<Tensor>& inputs) const {
     }
 
     for (int i = 0; i < input_shape.rank(); ++i) {
-        TT_ASSERT(input_shape[i] == output_shape[i]);
-        TT_ASSERT(input_shape_wo_padding[i] == output_shape_wo_padding[i]);
+        TT_FATAL(input_shape[i] == output_shape[i]);
+        TT_FATAL(input_shape_wo_padding[i] == output_shape_wo_padding[i]);
     }
 }
 
@@ -49,11 +50,15 @@ std::vector<Shape> Prod::compute_output_shapes(const std::vector<Tensor>& inputs
 
 operation::ProgramWithCallbacks Prod::create_program(
     const std::vector<Tensor>& inputs, std::vector<Tensor>& outputs) const {
-    TT_ASSERT((dim >= 0 && dim <= 3), "dim should be 0 - 3");
     auto& input = inputs.at(0);
     auto& output = inputs.at(1);
-    return prod_nc(input, output, dim);
+
+    if (dim == 0 || dim == 1) {
+        return prod_nc(input, output, dim);
+    } else {
+        return prod_hw(input, output);
+    }
 }
 
 inline Shape compute_output_shape(const Shape& input_shape, const int64_t& dim) {
@@ -61,7 +66,9 @@ inline Shape compute_output_shape(const Shape& input_shape, const int64_t& dim)
     auto padding = output_shape.padding();
     switch (dim) {
         case 0:
-        case 1: output_shape[dim] = 1;
+        case 1:
+        case 2:
+        case 3: output_shape[dim] = 1;
             break;
     }
diff --git a/tt_eager/tt_dnn/op_library/prod/prod_nc_op.hpp b/tt_eager/tt_dnn/op_library/prod/prod_nc_op.hpp
index a038a3ea08d..531f310cabc 100644
--- a/tt_eager/tt_dnn/op_library/prod/prod_nc_op.hpp
+++ b/tt_eager/tt_dnn/op_library/prod/prod_nc_op.hpp
@@ -34,6 +34,8 @@ struct Prod {
 
 operation::ProgramWithCallbacks prod_nc(const Tensor &input, const Tensor &output, int64_t dim);
 
+operation::ProgramWithCallbacks prod_hw(const Tensor &input, const Tensor &output);
+
 Tensor prod_(
     const Tensor &input,
     std::optional<std::reference_wrapper<const Tensor>> output,