From c8f391b384787d6dbb5a233f650aa8e029d9da02 Mon Sep 17 00:00:00 2001 From: ruthreshk Date: Thu, 25 Jan 2024 15:53:27 +0000 Subject: [PATCH] #3900: Add prod support for batch and channels --- .../unit_testing/test_prod.py | 83 +++++ tt_eager/tt_dnn/module.mk | 2 + .../op_library/prod/kernels/prod_nc.cpp | 59 ++++ .../prod/kernels/reader_prod_nc.cpp | 59 ++++ .../tt_dnn/op_library/prod/kernels/utils.hpp | 298 ++++++++++++++++++ .../prod/kernels/writer_prod_nc.cpp | 39 +++ .../op_library/prod/prod_nc/prod_nc.cpp | 193 ++++++++++++ .../tt_dnn/op_library/prod/prod_nc_op.cpp | 125 ++++++++ .../tt_dnn/op_library/prod/prod_nc_op.hpp | 56 ++++ .../tt_lib/csrc/operations/primary/module.hpp | 10 + 10 files changed, 924 insertions(+) create mode 100644 tests/tt_eager/python_api_testing/unit_testing/test_prod.py create mode 100644 tt_eager/tt_dnn/op_library/prod/kernels/prod_nc.cpp create mode 100644 tt_eager/tt_dnn/op_library/prod/kernels/reader_prod_nc.cpp create mode 100644 tt_eager/tt_dnn/op_library/prod/kernels/utils.hpp create mode 100644 tt_eager/tt_dnn/op_library/prod/kernels/writer_prod_nc.cpp create mode 100644 tt_eager/tt_dnn/op_library/prod/prod_nc/prod_nc.cpp create mode 100644 tt_eager/tt_dnn/op_library/prod/prod_nc_op.cpp create mode 100644 tt_eager/tt_dnn/op_library/prod/prod_nc_op.hpp diff --git a/tests/tt_eager/python_api_testing/unit_testing/test_prod.py b/tests/tt_eager/python_api_testing/unit_testing/test_prod.py new file mode 100644 index 00000000000..a72a5d4d0dd --- /dev/null +++ b/tests/tt_eager/python_api_testing/unit_testing/test_prod.py @@ -0,0 +1,83 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. 
+
+# SPDX-License-Identifier: Apache-2.0
+
+import pytest
+import torch
+from loguru import logger
+
+import tt_lib as ttl
+from models.utility_functions import comp_allclose_and_pcc, skip_for_wormhole_b0
+
+TILE_HEIGHT = 32
+TILE_WIDTH = 32
+
+
+def get_tensors(input_shape, output_shape, device):
+ torch.manual_seed(2023)
+ npu_dtype = ttl.tensor.DataType.BFLOAT16
+ cpu_dtype = torch.bfloat16
+ npu_layout = ttl.tensor.Layout.TILE
+
+ torch_input = torch.randint(1, 5, input_shape, dtype=cpu_dtype, requires_grad=True)
+ print("torch_input ===> ", torch_input)
+ torch_output = torch.randint(1, 5, output_shape, dtype=cpu_dtype)
+
+ tt_input = ttl.tensor.Tensor(torch_input, npu_dtype).pad_to_tile(float("nan")).to(npu_layout).to(device)
+ tt_output = ttl.tensor.Tensor(torch_output, npu_dtype).pad_to_tile(float("nan")).to(npu_layout).to(device)
+
+ return tt_input, tt_output, torch_input
+
+
+# Dongjin : WH_B0 skips this test due to the problem of sum reduction for w-dim.
+@skip_for_wormhole_b0()
+@pytest.mark.parametrize(
+ "input_shape",
+ (([2, 2, TILE_HEIGHT - 1, TILE_WIDTH - 1]),),
+ ids=[
+ "2, 2, TILE_HEIGHT-1,TILE_WIDTH - 1",
+ ],
+)
+@pytest.mark.parametrize(
+ "dims",
+ (
+ [
+ 0,
+ ],
+ [
+ 1,
+ ],
+ ),
+ ids=["0", "1"],
+)
+def test_moreh_prod_dims(input_shape, dims, device):
+ output_shape = input_shape.copy()
+
+ for dim in dims:
+ output_shape[dim] = 1
+
+ (tt_input, tt_output, torch_input) = get_tensors(input_shape, output_shape, device)
+
+ torch_output = torch.prod(torch_input, dims[0], True)
+
+ cpu_layout = ttl.tensor.Layout.ROW_MAJOR
+ tt_output_cpu = (
+ ttl.operations.primary.prod(tt_input, tt_output, dims=dims)
+ .cpu()
+ .to(cpu_layout)
+ .unpad_from_tile(output_shape)
+ .to_torch()
+ )
+
+ # print("torch_output ===> ", torch_output)
+ # print("tt_output_cpu ===> ", tt_output_cpu)
+
+ # test for equivalance
+ # TODO(Dongjin) : check while changing rtol after enabling fp32_dest_acc_en
+ rtol = atol = 0.12
+ passing, output_pcc = 
comp_allclose_and_pcc(torch_output, tt_output_cpu, pcc=0.999, rtol=rtol, atol=atol) + + logger.info(f"Out passing={passing}") + logger.info(f"Output pcc={output_pcc}") + + assert passing diff --git a/tt_eager/tt_dnn/module.mk b/tt_eager/tt_dnn/module.mk index 3422992b80f..93182598c3f 100644 --- a/tt_eager/tt_dnn/module.mk +++ b/tt_eager/tt_dnn/module.mk @@ -89,6 +89,8 @@ TT_DNN_SRCS = \ tt_eager/tt_dnn/op_library/moreh_sum/moreh_sum_w/moreh_sum_w.cpp \ tt_eager/tt_dnn/op_library/moreh_sum/moreh_sum_nc/moreh_sum_nc.cpp \ tt_eager/tt_dnn/op_library/moreh_sum/moreh_sum_op.cpp \ + tt_eager/tt_dnn/op_library/prod/prod_nc/prod_nc.cpp \ + tt_eager/tt_dnn/op_library/prod/prod_nc_op.cpp \ tt_eager/tt_dnn/op_library/moreh_sum_backward/moreh_sum_backward.cpp \ tt_eager/tt_dnn/op_library/moreh_sum_backward/moreh_sum_backward_op.cpp \ tt_eager/tt_dnn/op_library/layernorm/layernorm_op.cpp \ diff --git a/tt_eager/tt_dnn/op_library/prod/kernels/prod_nc.cpp b/tt_eager/tt_dnn/op_library/prod/kernels/prod_nc.cpp new file mode 100644 index 00000000000..44c564c26ea --- /dev/null +++ b/tt_eager/tt_dnn/op_library/prod/kernels/prod_nc.cpp @@ -0,0 +1,59 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. 
+// +// SPDX-License-Identifier: Apache-2.0 + +#include + +#include "compute_kernel_api/eltwise_binary.h" +#include "compute_kernel_api/tile_move_copy.h" + +ALWI void ACQ() { acquire_dst(tt::DstMode::Half); } +ALWI void REL() { release_dst(tt::DstMode::Half); } + +namespace NAMESPACE { +void MAIN { + const auto num_input_tiles = get_arg_val(0); + const auto num_output_tiles = get_arg_val(1); + + constexpr auto cb_in0 = tt::CB::c_in0; + constexpr auto cb_in1 = tt::CB::c_in1; + constexpr auto cb_out0 = tt::CB::c_out0; + constexpr auto cb_intermed0 = tt::CB::c_intermed0; + constexpr uint32_t onetile = 1; + constexpr uint32_t dst0 = 0; + constexpr uint32_t dst1 = 1; + constexpr uint32_t first_tile = 0; + + binary_op_init_common(tt::CB::c_in0, tt::CB::c_in1); + cb_wait_front(cb_in1, onetile); + + for (uint32_t i = 0; i < num_output_tiles; i++) { + bool enable_reload = false; + for (uint32_t j = 0; j < num_input_tiles; ++j) { + bool last_out = (j == num_input_tiles - 1); + uint32_t cb_mul = (enable_reload) ? (cb_intermed0) : (cb_in1); + + ACQ(); + cb_wait_front(cb_in0, onetile); + if (enable_reload) { + cb_wait_front(cb_intermed0, onetile); + } + + mul_tiles_init(); + mul_tiles(cb_in0, cb_mul, first_tile, first_tile, dst0); + + cb_pop_front(cb_in0, onetile); + if (enable_reload) { + cb_pop_front(cb_intermed0, onetile); + } + + uint32_t cb_out = (last_out) ? (cb_out0) : (cb_intermed0); + cb_reserve_back(cb_out, onetile); + pack_tile(dst0, cb_out); + cb_push_back(cb_out, onetile); + REL(); + enable_reload = true; + } + } +} +} // namespace NAMESPACE diff --git a/tt_eager/tt_dnn/op_library/prod/kernels/reader_prod_nc.cpp b/tt_eager/tt_dnn/op_library/prod/kernels/reader_prod_nc.cpp new file mode 100644 index 00000000000..11202694aab --- /dev/null +++ b/tt_eager/tt_dnn/op_library/prod/kernels/reader_prod_nc.cpp @@ -0,0 +1,59 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. 
+// +// SPDX-License-Identifier: Apache-2.0 + +#include + +#include "dataflow_api.h" +#include "tt_eager/tt_dnn/op_library/prod/kernels/utils.hpp" + +inline uint32_t get_read_tile_id(uint32_t tile_id, uint32_t dim, uint32_t input_tile_offset, uint32_t HtWt) { + return (dim == 0 ) ? (tile_id) : (tile_id / HtWt * input_tile_offset) + (tile_id % HtWt); +} + +void kernel_main() { + const auto input_addr = get_arg_val(0); + const auto num_input_tiles = get_arg_val(1); + const auto num_output_tiles = get_arg_val(2); + const auto input_tile_offset = get_arg_val(3); + const auto start_id = get_arg_val(4); + const auto input_is_dram = (get_arg_val(5) == 1); + const auto HtWt = get_arg_val(6); + const auto CHtWt = get_arg_val(7); + const auto dim = get_arg_val(8); + + constexpr uint32_t onetile = 1; + constexpr uint32_t cb_id_in0 = 0; + constexpr uint32_t cb_id_in1 = 1; + + union { + float f; + uint32_t u; + } scaler; + scaler.f = 1.0f; + fill_cb_with_value(cb_id_in1, scaler.u); + + uint32_t l1_write_addr_in0; + uint32_t input_tile_bytes = get_tile_size(cb_id_in0); + const auto input_data_format = get_dataformat(cb_id_in0); + const InterleavedAddrGenFast dram_input_addrg = { + .bank_base_address = input_addr, .page_size = input_tile_bytes, .data_format = input_data_format}; + const InterleavedAddrGenFast l1_input_addrg = { + .bank_base_address = input_addr, .page_size = input_tile_bytes, .data_format = input_data_format}; + + for (uint32_t i = start_id; i < start_id + num_output_tiles; i++) { + auto read_tile_id = get_read_tile_id(i, dim, CHtWt, HtWt); + for (uint32_t j = 0; j < num_input_tiles; ++j) { + cb_reserve_back(cb_id_in0, onetile); + l1_write_addr_in0 = get_write_ptr(cb_id_in0); + if (input_is_dram) { + noc_async_read_tile(read_tile_id, dram_input_addrg, l1_write_addr_in0); + } else { + noc_async_read_tile(read_tile_id, l1_input_addrg, l1_write_addr_in0); + } + noc_async_read_barrier(); + cb_push_back(cb_id_in0, onetile); + read_tile_id += input_tile_offset; + } + } 
+} diff --git a/tt_eager/tt_dnn/op_library/prod/kernels/utils.hpp b/tt_eager/tt_dnn/op_library/prod/kernels/utils.hpp new file mode 100644 index 00000000000..c50d0254e34 --- /dev/null +++ b/tt_eager/tt_dnn/op_library/prod/kernels/utils.hpp @@ -0,0 +1,298 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include + +#include "dataflow_api.h" + +void fill_cb_with_value(uint32_t cb_id, uint32_t value, int32_t num_of_elems = 1024) { + cb_reserve_back(cb_id, 1); + auto ptr = reinterpret_cast(get_write_ptr(cb_id)); + for (int j = 0; j < num_of_elems; j++) { + ptr[j] = uint16_t(value >> 16); + } + cb_push_back(cb_id, 1); +} + +void generate_mask_h_w(uint32_t cb_mask_h_w, uint32_t mask_h, uint32_t mask_w, uint32_t single_tile_size = 2048) { + union { + float f; + uint32_t u; + } one; + one.f = 1.0f; + union { + float f; + uint32_t u; + } zero; + zero.f = 0.0f; + + const auto u16_one = uint16_t(one.u >> 16); + const auto u16_zero = uint16_t(zero.u >> 16); + + cb_reserve_back(cb_mask_h_w, 2); + + // mask_h + // first tile ptr + auto mask_h_ptr = reinterpret_cast(get_write_ptr(cb_mask_h_w)); + for (uint32_t w = 0; w < 16; w++) { + // sub tile 0 + { + uint32_t mask_h_0 = mask_h; + if (mask_h_0 >= 16) { + mask_h_0 = 16; + } + uint32_t h = 0; + for (; h < mask_h_0; h++) { + mask_h_ptr[h * 16 + w] = u16_one; + } + for (; h < 16; h++) { + mask_h_ptr[h * 16 + w] = u16_zero; + } + } + + // sub tile 1 + { + uint32_t mask_h_0 = mask_h; + if (mask_h_0 >= 16) { + mask_h_0 = 16; + } + uint32_t h = 0; + for (; h < mask_h_0; h++) { + mask_h_ptr[h * 16 + w + 256] = u16_one; + } + for (; h < 16; h++) { + mask_h_ptr[h * 16 + w + 256] = u16_zero; + } + } + + // sub tile 2 + { + uint32_t mask_h_1 = (mask_h < 16) ? 
0 : mask_h - 16; + uint32_t h = 0; + for (; h < mask_h_1; h++) { + mask_h_ptr[h * 16 + w + 512] = u16_one; + } + for (; h < 16; h++) { + mask_h_ptr[h * 16 + w + 512] = u16_zero; + } + } + + // sub tile 3 + { + uint32_t mask_h_1 = (mask_h < 16) ? 0 : mask_h - 16; + uint32_t h = 0; + for (; h < mask_h_1; h++) { + mask_h_ptr[h * 16 + w + 768] = u16_one; + } + for (; h < 16; h++) { + mask_h_ptr[h * 16 + w + 768] = u16_zero; + } + } + } + + // mask_w + // second tile ptr + auto mask_w_ptr = reinterpret_cast(get_write_ptr(cb_mask_h_w) + single_tile_size); + for (uint32_t h = 0; h < 16; h++) { + // sub tile 0 + { + uint32_t mask_w_0 = mask_w; + if (mask_w_0 >= 16) { + mask_w_0 = 16; + } + uint32_t w = 0; + for (; w < mask_w_0; w++) { + mask_w_ptr[h * 16 + w] = u16_one; + } + for (; w < 16; w++) { + mask_w_ptr[h * 16 + w] = u16_zero; + } + } + + // sub tile 1 + { + uint32_t mask_w_1 = (mask_w < 16) ? 0 : mask_w - 16; + uint32_t w = 0; + for (; w < mask_w_1; w++) { + mask_w_ptr[h * 16 + w + 256] = u16_one; + } + for (; w < 16; w++) { + mask_w_ptr[h * 16 + w + 256] = u16_zero; + } + } + + // sub tile 2 + { + uint32_t mask_w_0 = mask_w; + if (mask_w_0 >= 16) { + mask_w_0 = 16; + } + uint32_t w = 0; + for (; w < mask_w_0; w++) { + mask_w_ptr[h * 16 + w + 512] = u16_one; + } + for (; w < 16; w++) { + mask_w_ptr[h * 16 + w + 512] = u16_zero; + } + } + + // sub tile 3 + { + uint32_t mask_w_1 = (mask_w < 16) ? 
0 : mask_w - 16; + uint32_t w = 0; + for (; w < mask_w_1; w++) { + mask_w_ptr[h * 16 + w + 768] = u16_one; + } + for (; w < 16; w++) { + mask_w_ptr[h * 16 + w + 768] = u16_zero; + } + } + } + + cb_push_back(cb_mask_h_w, 2); +} + +void generate_mask_w(uint32_t cb_mask, uint32_t mask_w) { + union { + float f; + uint32_t u; + } one; + one.f = 1.0f; + union { + float f; + uint32_t u; + } zero; + zero.f = 0.0f; + + cb_reserve_back(cb_mask, 1); + auto ptr = reinterpret_cast(get_write_ptr(cb_mask)); + + for (uint32_t h = 0; h < 16; h++) { + // sub tile 0 + { + uint32_t mask_w_0 = mask_w; + if (mask_w_0 >= 16) + mask_w_0 = 16; + uint32_t w = 0; + for (; w < mask_w_0; w++) { + ptr[h * 16 + w] = uint16_t(one.u >> 16); + } + for (; w < 16; w++) { + ptr[h * 16 + w] = uint16_t(zero.u >> 16); + } + } + + // sub tile 1 + { + uint32_t mask_w_1 = (mask_w < 16) ? 0 : mask_w - 16; + uint32_t w = 0; + for (; w < mask_w_1; w++) { + ptr[h * 16 + w + 256] = uint16_t(one.u >> 16); + } + for (; w < 16; w++) { + ptr[h * 16 + w + 256] = uint16_t(zero.u >> 16); + } + } + + // sub tile 2 + { + uint32_t mask_w_0 = mask_w; + if (mask_w_0 >= 16) + mask_w_0 = 16; + uint32_t w = 0; + for (; w < mask_w_0; w++) { + ptr[h * 16 + w + 512] = uint16_t(one.u >> 16); + } + for (; w < 16; w++) { + ptr[h * 16 + w + 512] = uint16_t(zero.u >> 16); + } + } + + // sub tile 3 + { + uint32_t mask_w_1 = (mask_w < 16) ? 
0 : mask_w - 16; + uint32_t w = 0; + for (; w < mask_w_1; w++) { + ptr[h * 16 + w + 768] = uint16_t(one.u >> 16); + } + for (; w < 16; w++) { + ptr[h * 16 + w + 768] = uint16_t(zero.u >> 16); + } + } + } + + cb_push_back(cb_mask, 1); +} + +void generate_mask_h(uint32_t cb_mask, uint32_t mask_h) { + union { + float f; + uint32_t u; + } one; + one.f = 1.0f; + union { + float f; + uint32_t u; + } zero; + zero.f = 0.0f; + + cb_reserve_back(cb_mask, 1); + auto ptr = reinterpret_cast(get_write_ptr(cb_mask)); + + for (uint32_t w = 0; w < 16; w++) { + // sub tile 0 + { + uint32_t mask_h_0 = mask_h; + if (mask_h_0 >= 16) + mask_h_0 = 16; + uint32_t h = 0; + for (; h < mask_h_0; h++) { + ptr[h * 16 + w] = uint16_t(one.u >> 16); + } + for (; h < 16; h++) { + ptr[h * 16 + w] = uint16_t(zero.u >> 16); + } + } + + // sub tile 1 + { + uint32_t mask_h_0 = mask_h; + if (mask_h_0 >= 16) + mask_h_0 = 16; + uint32_t h = 0; + for (; h < mask_h_0; h++) { + ptr[h * 16 + w + 256] = uint16_t(one.u >> 16); + } + for (; h < 16; h++) { + ptr[h * 16 + w + 256] = uint16_t(zero.u >> 16); + } + } + + // sub tile 2 + { + uint32_t mask_h_1 = (mask_h < 16) ? 0 : mask_h - 16; + uint32_t h = 0; + for (; h < mask_h_1; h++) { + ptr[h * 16 + w + 512] = uint16_t(one.u >> 16); + } + for (; h < 16; h++) { + ptr[h * 16 + w + 512] = uint16_t(zero.u >> 16); + } + } + + // sub tile 3 + { + uint32_t mask_h_1 = (mask_h < 16) ? 0 : mask_h - 16; + uint32_t h = 0; + for (; h < mask_h_1; h++) { + ptr[h * 16 + w + 768] = uint16_t(one.u >> 16); + } + for (; h < 16; h++) { + ptr[h * 16 + w + 768] = uint16_t(zero.u >> 16); + } + } + } + + cb_push_back(cb_mask, 1); +} diff --git a/tt_eager/tt_dnn/op_library/prod/kernels/writer_prod_nc.cpp b/tt_eager/tt_dnn/op_library/prod/kernels/writer_prod_nc.cpp new file mode 100644 index 00000000000..0efd154c76f --- /dev/null +++ b/tt_eager/tt_dnn/op_library/prod/kernels/writer_prod_nc.cpp @@ -0,0 +1,39 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. 
+// +// SPDX-License-Identifier: Apache-2.0 + +#include + +#include "dataflow_api.h" + +void kernel_main() { + const auto output_addr = get_arg_val(0); + const auto num_tiles = get_arg_val(1); + const auto start_id = get_arg_val(2); + const auto output_is_dram = (get_arg_val(3) == 1); + + constexpr uint32_t cb_id_out = 16; + constexpr uint32_t onetile = 1; + + uint32_t output_tile_bytes = get_tile_size(cb_id_out); + const auto output_data_format = get_dataformat(cb_id_out); + + const InterleavedAddrGenFast dram_output_addrg = { + .bank_base_address = output_addr, .page_size = output_tile_bytes, .data_format = output_data_format}; + const InterleavedAddrGenFast l1_output_addrg = { + .bank_base_address = output_addr, .page_size = output_tile_bytes, .data_format = output_data_format}; + + for (uint32_t i = start_id; i < start_id + num_tiles; i++) { + uint32_t write_tile_id = i; + cb_wait_front(cb_id_out, onetile); + + uint32_t l1_read_addr = get_read_ptr(cb_id_out); + if (output_is_dram) { + noc_async_write_tile(write_tile_id, dram_output_addrg, l1_read_addr); + } else { + noc_async_write_tile(write_tile_id, l1_output_addrg, l1_read_addr); + } + noc_async_write_barrier(); + cb_pop_front(cb_id_out, onetile); + } +} diff --git a/tt_eager/tt_dnn/op_library/prod/prod_nc/prod_nc.cpp b/tt_eager/tt_dnn/op_library/prod/prod_nc/prod_nc.cpp new file mode 100644 index 00000000000..e11f3966946 --- /dev/null +++ b/tt_eager/tt_dnn/op_library/prod/prod_nc/prod_nc.cpp @@ -0,0 +1,193 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. 
+//
+// SPDX-License-Identifier: Apache-2.0
+
+#include "tt_dnn/op_library/prod/prod_nc_op.hpp"
+#include "tt_eager/tt_dnn/op_library/moreh_helper_functions.hpp"
+#include "tt_eager/tt_dnn/op_library/work_split.hpp"
+#include "tt_metal/common/constants.hpp"
+#include "tt_metal/detail/util.hpp"
+#include "tt_metal/host_api.hpp"
+
+namespace tt {
+using namespace constants;
+namespace operations {
+
+namespace primary {
+
+operation::ProgramWithCallbacks prod_nc(const Tensor &input, const Tensor &output, int64_t dim) {
+ TT_ASSERT(dim == 0 || dim == 1);
+
+ ////////////////////////////////////////////////////////////////////////////
+ // Device Setup
+ ////////////////////////////////////////////////////////////////////////////
+ auto *device = input.device();
+ auto program = Program();
+
+ ////////////////////////////////////////////////////////////////////////////
+ // Parameters Setup
+ ////////////////////////////////////////////////////////////////////////////
+ const auto cb_data_format = datatype_to_dataformat_converter(output.dtype());
+ const auto single_tile_size = detail::TileSize(cb_data_format);
+
+ const auto &input_shape = input.shape();
+ const auto &input_shape_without_padding = input_shape.without_padding();
+
+ const auto N = input_shape[0];
+ const auto C = input_shape[1];
+ const auto Ht = input_shape[2] / TILE_HEIGHT;
+ const auto Wt = input_shape[3] / TILE_WIDTH;
+ const auto HtWt = Ht * Wt;
+ const auto CHtWt = C * Ht * Wt;
+ const auto num_reduce_input_tile = input_shape[dim];
+ const auto input_tile_offset = (dim == 0) ? 
(CHtWt) : (HtWt); + const auto num_output_tiles = output.volume() / TILE_HW; + + log_debug(LogTest, "N {} C {} Ht {} Wt {}", N, C, Ht, Wt); + log_debug( + LogTest, + "dim {} num_reduce_input_tile {} input_tile_offset {}, num_output_tiles {}", + dim, + num_reduce_input_tile, + input_tile_offset, + num_output_tiles); + + //////////////////////////////////////////////////////////////////////////// + // Core Setup + //////////////////////////////////////////////////////////////////////////// + CoreGridDesc core_grid(device); + const auto num_cores_y = core_grid.y_; + CoreCoord core_grid_coord = {.x = core_grid.x_, .y = num_cores_y}; + + const uint32_t in0_t = 2; // input + const uint32_t in1_t = 1; // zero + const uint32_t intermed0_t = 1; // accumulated mul + const uint32_t out0_t = 2; // output + const auto + [num_cores_to_be_used, + all_cores, + core_group_1, + core_group_2, + num_cols_per_core_group_1, + num_cols_per_core_group_2] = tt_metal::split_work_to_cores(core_grid_coord, num_output_tiles); + + //////////////////////////////////////////////////////////////////////////// + // CircularBuffer Setup + //////////////////////////////////////////////////////////////////////////// + CreateCircularBuffer( + program, + all_cores, + cb_data_format, + { + {CB::c_in0, in0_t}, // input + {CB::c_in1, in1_t}, // zero + {CB::c_intermed0, intermed0_t}, // accumulated mul + {CB::c_out0, out0_t}, // output + }); + + //////////////////////////////////////////////////////////////////////////// + // DataMovementKernel SetUp + //////////////////////////////////////////////////////////////////////////// + std::vector reader_compile_time_args; + std::vector writer_compile_time_args; + const auto reader_kernel_file = "tt_eager/tt_dnn/op_library/prod/kernels/reader_prod_nc.cpp"; + const auto writer_kernel_file = "tt_eager/tt_dnn/op_library/prod/kernels/writer_prod_nc.cpp"; + const auto reader_kernel_id = CreateReadKernel(program, reader_kernel_file, all_cores, 
reader_compile_time_args); + const auto writer_kernel_id = CreateWriteKernel(program, writer_kernel_file, all_cores, writer_compile_time_args); + + //////////////////////////////////////////////////////////////////////////// + // ComputeKernel SetUp + //////////////////////////////////////////////////////////////////////////// + const std::vector compute_args_group_1{num_cols_per_core_group_1}; + std::map compute_defines; + const auto compute_kernel_file = "tt_eager/tt_dnn/op_library/prod/kernels/prod_nc.cpp"; + const auto compute_kernel_1_id = CreateComputeKernel( + program, compute_kernel_file, {core_group_1, num_cols_per_core_group_1, compute_args_group_1}, compute_defines); + + std::optional compute_kernel_2_id = std::nullopt; + if (!core_group_2.ranges().empty()) { + const std::vector compute_args_group_2{num_cols_per_core_group_2}; + compute_kernel_2_id = CreateComputeKernel( + program, + compute_kernel_file, + {core_group_2, num_cols_per_core_group_2, compute_args_group_2}, + compute_defines); + } + + //////////////////////////////////////////////////////////////////////////// + // RuntimeArgs SetUp + //////////////////////////////////////////////////////////////////////////// + for (uint32_t i = 0, tile_offset = 0; i < num_cores_to_be_used; ++i) { + CoreCoord core = {i / num_cores_y, i % num_cores_y}; + + uint32_t num_tiles_per_core; + if (core_group_1.core_coord_in_core_ranges(core)) { + num_tiles_per_core = num_cols_per_core_group_1; + } else if (core_group_2.core_coord_in_core_ranges(core)) { + num_tiles_per_core = num_cols_per_core_group_2; + } else { + TT_THROW("Core not in specified core ranges."); + } + + SetRuntimeArgs( + program, + reader_kernel_id, + core, + {input.buffer()->address(), + num_reduce_input_tile, + num_tiles_per_core, + input_tile_offset, + tile_offset, + static_cast(is_dram(input)), + HtWt, + CHtWt, + static_cast(dim) + }); + + SetRuntimeArgs( + program, + writer_kernel_id, + core, + {output.buffer()->address(), num_tiles_per_core, 
tile_offset, static_cast(is_dram(output))}); + + if (core_group_1.core_coord_in_core_ranges(core)) { + SetRuntimeArgs(program, compute_kernel_1_id, core, {num_reduce_input_tile, num_tiles_per_core}); + } else if (core_group_2.core_coord_in_core_ranges(core)) { + TT_ASSERT(compute_kernel_2_id.has_value()); + SetRuntimeArgs(program, compute_kernel_2_id.value(), core, {num_reduce_input_tile, num_tiles_per_core}); + } else { + TT_ASSERT(false, "Core not in specified core ranges."); + } + tile_offset += num_tiles_per_core; + } + + auto override_runtime_arguments_callback = [reader_kernel_id, writer_kernel_id, num_cores_to_be_used, num_cores_y]( + const void *operation, + const Program &program, + const std::vector &input_tensors, + const std::vector> &, + const std::vector &output_tensors) { + const auto *input_buffer = input_tensors.at(0).buffer(); + const auto *output_buffer = input_tensors.at(1).buffer(); + for (uint32_t i = 0; i < num_cores_to_be_used; ++i) { + CoreCoord core = {i / num_cores_y, i % num_cores_y}; + { + auto runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); + runtime_args[0] = input_buffer->address(); + SetRuntimeArgs(program, reader_kernel_id, core, runtime_args); + } + + { + auto runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); + runtime_args[0] = output_buffer->address(); + SetRuntimeArgs(program, writer_kernel_id, core, runtime_args); + } + } + }; + + return {.program = std::move(program), .override_runtime_arguments_callback = override_runtime_arguments_callback}; +} + +} // namespace primary +} // namespace operations +} // namespace tt diff --git a/tt_eager/tt_dnn/op_library/prod/prod_nc_op.cpp b/tt_eager/tt_dnn/op_library/prod/prod_nc_op.cpp new file mode 100644 index 00000000000..f5902f80d43 --- /dev/null +++ b/tt_eager/tt_dnn/op_library/prod/prod_nc_op.cpp @@ -0,0 +1,125 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. 
+// +// SPDX-License-Identifier: Apache-2.0 + +#include "tt_dnn/op_library/prod/prod_nc_op.hpp" +#include "tt_dnn/op_library/reduce/reduce_op.hpp" +#include "tt_eager/tt_dnn/op_library/moreh_helper_functions.hpp" +#include "tt_metal/common/constants.hpp" +#include "tt_metal/host_api.hpp" + +namespace tt { +using namespace constants; +namespace operations { +namespace primary { + +//////////////////////////////////////////////////////////////////////////// +// Prod +//////////////////////////////////////////////////////////////////////////// +void Prod::validate(const std::vector& inputs) const { + TT_ASSERT((dim >= 0 && dim <= 3), "dim should be 0 - 3"); + const auto& input = inputs.at(0); + const auto& output = inputs.at(1); + + auto input_shape = input.shape(); + const auto& output_shape = output.shape(); + auto input_shape_wo_padding = input.shape().without_padding(); + const auto& output_shape_wo_padding = output.shape().without_padding(); + + if (dim == 0 || dim == 1) { + input_shape[dim] = 1; + input_shape_wo_padding[dim] = 1; + } + + for (int i = 0; i < input_shape.rank(); ++i) { + TT_ASSERT(input_shape[i] == output_shape[i]); + TT_ASSERT(input_shape_wo_padding[i] == output_shape_wo_padding[i]); + } +} + +std::vector Prod::create_output_tensors(const std::vector& inputs) const { + // Inplace + return {}; +} + +std::vector Prod::compute_output_shapes(const std::vector& inputs) const { + // Inplace + return {}; +} + +operation::ProgramWithCallbacks Prod::create_program( + const std::vector& inputs, std::vector& outputs) const { + TT_ASSERT((dim >= 0 && dim <= 3), "dim should be 0 - 3"); + auto& input = inputs.at(0); + auto& output = inputs.at(1); + + if (dim == 0 || dim == 1) { + return prod_nc(input, output, dim); + } + else { + return prod_w(input, output); + } +} + +inline Shape compute_output_shape(const Shape& input_shape, const int64_t& dim) { + auto output_shape = input_shape; + auto padding = output_shape.padding(); + switch (dim) { + case 0: + case 1: 
output_shape[dim] = 1; + break; + } + + return {Shape(output_shape, padding)}; +} + +inline Tensor create_output_tensor( + const Tensor& input_tensor, const Shape& output_shape, const MemoryConfig& mem_config) { + TT_ASSERT(input_tensor.storage_type() == StorageType::DEVICE); + return create_device_tensor(output_shape, input_tensor.dtype(), Layout::TILE, input_tensor.device(), mem_config); +} + +// output as arg +Tensor prod_(const Tensor& input, const Tensor& output, const int64_t& dim) { + operation::run(Prod{.dim = dim}, {input, output}); + return output; +} + +// output creation inside +Tensor prod_(const Tensor& input, const int64_t& dim, const MemoryConfig& mem_config) { + const auto& input_shape = input.shape(); + const auto& output_shape = compute_output_shape(input_shape, dim); + auto output = create_output_tensor(input, output_shape, mem_config); + + const auto& output_shape_wo_padding = output.shape().without_padding(); + operation::run(Prod{.dim = dim}, {input, output}); + return output; +} + +Tensor prod( + const Tensor& input, + const Tensor& output, + std::vector& dims, + const MemoryConfig& mem_config) { + // reduce for all dims + if (dims.empty()) { + dims = {0, 1, 2, 3}; + } + + std::vector sorted_dims = dims; + std::sort(sorted_dims.begin(), sorted_dims.end()); + + auto temp_input = input; + for (uint32_t i = dims.size() - 1; i > 0; i--) { + log_debug(LogTest, "{}:{} dim {}", __func__, __LINE__, sorted_dims[i]); + auto temp_output = prod_(temp_input, sorted_dims[i], mem_config); + temp_input = temp_output; + } + log_debug(LogTest, "{}:{} dim {}", __func__, __LINE__, sorted_dims.front()); + prod_(temp_input, output, sorted_dims.front()); + return output; +} + +} // namespace primary +} // namespace operations +} // namespace tt diff --git a/tt_eager/tt_dnn/op_library/prod/prod_nc_op.hpp b/tt_eager/tt_dnn/op_library/prod/prod_nc_op.hpp new file mode 100644 index 00000000000..e2516c32b90 --- /dev/null +++ 
b/tt_eager/tt_dnn/op_library/prod/prod_nc_op.hpp @@ -0,0 +1,56 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include +#include +#include +#include + +#include "tt_dnn/op_library/run_operation.hpp" +#include "tt_eager/tensor/tensor.hpp" + +namespace tt { + +namespace operations { + +namespace primary { + +using namespace tt_metal; + +struct Prod { + int64_t dim; + void validate(const std::vector &inputs) const; + std::vector compute_output_shapes(const std::vector &inputs) const; + std::vector create_output_tensors(const std::vector &inputs) const; + operation::ProgramWithCallbacks create_program( + const std::vector &inputs, std::vector &outputs) const; + stl::reflection::Attributes attributes() const; + static constexpr auto attribute_names = std::make_tuple("dim"); + const auto attribute_values() const { return std::make_tuple(std::cref(this->dim)); } +}; + +operation::ProgramWithCallbacks prod_nc(const Tensor &input, const Tensor &output, int64_t dim); +// revised from reduce_op +operation::ProgramWithCallbacks prod_w(const Tensor &a, const Tensor &output); +// operation::ProgramWithCallbacks prod_h(const Tensor &a, const Tensor &output); + +Tensor prod_( + const Tensor &input, + std::optional> output, + const int64_t &dim, + const MemoryConfig &mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG); + +Tensor prod( + const Tensor &input, + const Tensor &output, + std::vector &dims, + const MemoryConfig &mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG); + +} // namespace primary + +} // namespace operations + +} // namespace tt diff --git a/tt_eager/tt_lib/csrc/operations/primary/module.hpp b/tt_eager/tt_lib/csrc/operations/primary/module.hpp index 69c6d4cc721..e7422bc8cb4 100644 --- a/tt_eager/tt_lib/csrc/operations/primary/module.hpp +++ b/tt_eager/tt_lib/csrc/operations/primary/module.hpp @@ -20,6 +20,7 @@ #include "tt_dnn/op_library/softmax/softmax_op.hpp" #include 
"tt_dnn/op_library/moreh_sum/moreh_sum_op.hpp" #include "tt_dnn/op_library/moreh_sum_backward/moreh_sum_backward_op.hpp" +#include "tt_dnn/op_library/prod/prod_nc_op.hpp" #include "tt_dnn/op_library/moreh_cumsum/moreh_cumsum_op.hpp" #include "tt_dnn/op_library/moreh_arange/moreh_arange_op.hpp" @@ -525,6 +526,15 @@ void py_module(py::module& m_primary) { py::arg("dims").noconvert() = std::vector(), py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, "Performs sum operation. Returns an output tensor."); + m_primary.def( + "prod", + &prod, + py::arg("input").noconvert(), + py::arg("output").noconvert(), + py::kw_only(), + py::arg("dims").noconvert() = std::vector(), + py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, + "Performs product operation. Returns an output tensor."); m_primary.def( "moreh_sum_backward", &moreh_sum_backward,