#3900: Add prod support for batch and channels
ruthreshx committed Jan 25, 2024
1 parent da3c5fa commit c8f391b
Showing 10 changed files with 924 additions and 0 deletions.
83 changes: 83 additions & 0 deletions tests/tt_eager/python_api_testing/unit_testing/test_prod.py
@@ -0,0 +1,83 @@
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import pytest
import torch
from loguru import logger

import tt_lib as ttl
from models.utility_functions import comp_allclose_and_pcc, skip_for_wormhole_b0

TILE_HEIGHT = 32
TILE_WIDTH = 32


def get_tensors(input_shape, output_shape, device):
    torch.manual_seed(2023)
    npu_dtype = ttl.tensor.DataType.BFLOAT16
    cpu_dtype = torch.bfloat16
    npu_layout = ttl.tensor.Layout.TILE

    torch_input = torch.randint(1, 5, input_shape, dtype=cpu_dtype, requires_grad=True)
    torch_output = torch.randint(1, 5, output_shape, dtype=cpu_dtype)

    # Pad to 32x32 tile boundaries with NaN (so stray reads of padding are
    # conspicuous), tilize, and move both tensors to the device.
    tt_input = ttl.tensor.Tensor(torch_input, npu_dtype).pad_to_tile(float("nan")).to(npu_layout).to(device)
    tt_output = ttl.tensor.Tensor(torch_output, npu_dtype).pad_to_tile(float("nan")).to(npu_layout).to(device)

    return tt_input, tt_output, torch_input


# Dongjin: WH_B0 skips this test due to a known issue with sum reduction along the w-dim.
@skip_for_wormhole_b0()
@pytest.mark.parametrize(
    "input_shape",
    (([2, 2, TILE_HEIGHT - 1, TILE_WIDTH - 1]),),
    ids=[
        "2, 2, TILE_HEIGHT - 1, TILE_WIDTH - 1",
    ],
)
@pytest.mark.parametrize(
    "dims",
    (
        [
            0,
        ],
        [
            1,
        ],
    ),
    ids=["0", "1"],
)
def test_moreh_prod_dims(input_shape, dims, device):
    output_shape = input_shape.copy()

    for dim in dims:
        output_shape[dim] = 1

    (tt_input, tt_output, torch_input) = get_tensors(input_shape, output_shape, device)

    # torch.prod reduces a single dim per call, so fold the product over dims.
    torch_output = torch_input
    for dim in dims:
        torch_output = torch.prod(torch_output, dim, True)

    cpu_layout = ttl.tensor.Layout.ROW_MAJOR
    tt_output_cpu = (
        ttl.operations.primary.prod(tt_input, tt_output, dims=dims)
        .cpu()
        .to(cpu_layout)
        .unpad_from_tile(output_shape)
        .to_torch()
    )

    # test for equivalence
    # TODO(Dongjin): check while changing rtol after enabling fp32_dest_acc_en
    rtol = atol = 0.12
    passing, output_pcc = comp_allclose_and_pcc(torch_output, tt_output_cpu, pcc=0.999, rtol=rtol, atol=atol)

    logger.info(f"Out passing={passing}")
    logger.info(f"Output pcc={output_pcc}")

    assert passing
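The pass criterion combines an elementwise allclose with a Pearson correlation coefficient (PCC) threshold. As a rough, illustrative sketch of what the PCC side of comp_allclose_and_pcc measures (not the utility's actual implementation):

import torch

def pcc(golden: torch.Tensor, observed: torch.Tensor) -> float:
    # Pearson correlation of the flattened tensors: 1.0 is a perfect
    # linear match; the test above requires at least 0.999.
    g = golden.flatten().to(torch.float32)
    o = observed.flatten().to(torch.float32)
    g, o = g - g.mean(), o - o.mean()
    return float((g * o).sum() / (g.norm() * o.norm() + 1e-12))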
2 changes: 2 additions & 0 deletions tt_eager/tt_dnn/module.mk
@@ -89,6 +89,8 @@ TT_DNN_SRCS = \
tt_eager/tt_dnn/op_library/moreh_sum/moreh_sum_w/moreh_sum_w.cpp \
tt_eager/tt_dnn/op_library/moreh_sum/moreh_sum_nc/moreh_sum_nc.cpp \
tt_eager/tt_dnn/op_library/moreh_sum/moreh_sum_op.cpp \
tt_eager/tt_dnn/op_library/prod/prod_nc/prod_nc.cpp \
tt_eager/tt_dnn/op_library/prod/prod_nc_op.cpp \
tt_eager/tt_dnn/op_library/moreh_sum_backward/moreh_sum_backward.cpp \
tt_eager/tt_dnn/op_library/moreh_sum_backward/moreh_sum_backward_op.cpp \
tt_eager/tt_dnn/op_library/layernorm/layernorm_op.cpp \
59 changes: 59 additions & 0 deletions tt_eager/tt_dnn/op_library/prod/kernels/prod_nc.cpp
@@ -0,0 +1,59 @@
// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#include <cstdint>

#include "compute_kernel_api/eltwise_binary.h"
#include "compute_kernel_api/tile_move_copy.h"

ALWI void ACQ() { acquire_dst(tt::DstMode::Half); }
ALWI void REL() { release_dst(tt::DstMode::Half); }

namespace NAMESPACE {
void MAIN {
    const auto num_input_tiles = get_arg_val<uint32_t>(0);
    const auto num_output_tiles = get_arg_val<uint32_t>(1);

    constexpr auto cb_in0 = tt::CB::c_in0;
    constexpr auto cb_in1 = tt::CB::c_in1;
    constexpr auto cb_out0 = tt::CB::c_out0;
    constexpr auto cb_intermed0 = tt::CB::c_intermed0;
    constexpr uint32_t onetile = 1;
    constexpr uint32_t dst0 = 0;
    constexpr uint32_t dst1 = 1;
    constexpr uint32_t first_tile = 0;

    binary_op_init_common(cb_in0, cb_in1);
    cb_wait_front(cb_in1, onetile);

    for (uint32_t i = 0; i < num_output_tiles; i++) {
        bool enable_reload = false;
        for (uint32_t j = 0; j < num_input_tiles; ++j) {
            bool last_out = (j == num_input_tiles - 1);
            // First pass multiplies by the ones tile in cb_in1; later passes
            // reload the running product from cb_intermed0.
            uint32_t cb_mul = (enable_reload) ? (cb_intermed0) : (cb_in1);

            ACQ();
            cb_wait_front(cb_in0, onetile);
            if (enable_reload) {
                cb_wait_front(cb_intermed0, onetile);
            }

            mul_tiles_init();
            mul_tiles(cb_in0, cb_mul, first_tile, first_tile, dst0);

            cb_pop_front(cb_in0, onetile);
            if (enable_reload) {
                cb_pop_front(cb_intermed0, onetile);
            }

            // Pack the partial product to cb_intermed0, or to cb_out0 on the
            // final input tile.
            uint32_t cb_out = (last_out) ? (cb_out0) : (cb_intermed0);
            cb_reserve_back(cb_out, onetile);
            pack_tile(dst0, cb_out);
            cb_push_back(cb_out, onetile);
            REL();
            enable_reload = true;
        }
    }
}
}  // namespace NAMESPACE
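For each output tile, the compute kernel above folds num_input_tiles input tiles into a running product, bouncing partial results through cb_intermed0 until the last multiply packs to cb_out0. A host-side analogue of that dataflow, as an illustrative sketch only:

import numpy as np

def prod_nc_sim(tiles):
    # acc plays the role of the ones tile the reader puts in cb_in1; each
    # pass multiplies one input tile (cb_in0) into the running product,
    # which the kernel would park in cb_intermed0 between passes.
    acc = np.ones_like(tiles[0])
    for tile in tiles:
        acc = tile * acc
    return acc

tiles = [np.full((32, 32), v, dtype=np.float32) for v in (2.0, 3.0, 4.0)]
assert (prod_nc_sim(tiles) == 24.0).all()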
59 changes: 59 additions & 0 deletions tt_eager/tt_dnn/op_library/prod/kernels/reader_prod_nc.cpp
@@ -0,0 +1,59 @@
// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#include <cstdint>

#include "dataflow_api.h"
#include "tt_eager/tt_dnn/op_library/prod/kernels/utils.hpp"

// Map an output tile index to the first input tile to read: for dim 0 (N)
// the mapping is the identity; for dim 1 (C) select the owning batch
// (tile_id / HtWt rows of input_tile_offset tiles) plus the H*W offset.
inline uint32_t get_read_tile_id(uint32_t tile_id, uint32_t dim, uint32_t input_tile_offset, uint32_t HtWt) {
    return (dim == 0) ? (tile_id) : (tile_id / HtWt * input_tile_offset) + (tile_id % HtWt);
}

void kernel_main() {
    const auto input_addr = get_arg_val<uint32_t>(0);
    const auto num_input_tiles = get_arg_val<uint32_t>(1);
    const auto num_output_tiles = get_arg_val<uint32_t>(2);
    const auto input_tile_offset = get_arg_val<uint32_t>(3);
    const auto start_id = get_arg_val<uint32_t>(4);
    const auto input_is_dram = (get_arg_val<uint32_t>(5) == 1);
    const auto HtWt = get_arg_val<uint32_t>(6);
    const auto CHtWt = get_arg_val<uint32_t>(7);
    const auto dim = get_arg_val<uint32_t>(8);

    constexpr uint32_t onetile = 1;
    constexpr uint32_t cb_id_in0 = 0;
    constexpr uint32_t cb_id_in1 = 1;

    // Fill cb_in1 with 1.0f; the compute kernel multiplies by this tile on
    // the first pass of each output tile.
    union {
        float f;
        uint32_t u;
    } scaler;
    scaler.f = 1.0f;
    fill_cb_with_value(cb_id_in1, scaler.u);

    uint32_t l1_write_addr_in0;
    uint32_t input_tile_bytes = get_tile_size(cb_id_in0);
    const auto input_data_format = get_dataformat(cb_id_in0);
    const InterleavedAddrGenFast<true> dram_input_addrg = {
        .bank_base_address = input_addr, .page_size = input_tile_bytes, .data_format = input_data_format};
    const InterleavedAddrGenFast<false> l1_input_addrg = {
        .bank_base_address = input_addr, .page_size = input_tile_bytes, .data_format = input_data_format};

    for (uint32_t i = start_id; i < start_id + num_output_tiles; i++) {
        auto read_tile_id = get_read_tile_id(i, dim, CHtWt, HtWt);
        for (uint32_t j = 0; j < num_input_tiles; ++j) {
            cb_reserve_back(cb_id_in0, onetile);
            l1_write_addr_in0 = get_write_ptr(cb_id_in0);
            if (input_is_dram) {
                noc_async_read_tile(read_tile_id, dram_input_addrg, l1_write_addr_in0);
            } else {
                noc_async_read_tile(read_tile_id, l1_input_addrg, l1_write_addr_in0);
            }
            noc_async_read_barrier();
            cb_push_back(cb_id_in0, onetile);
            // Stride to the next tile along the reduced dim.
            read_tile_id += input_tile_offset;
        }
    }
}
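The subtle part of the reader is the indexing: get_read_tile_id picks the first input tile of each reduction chain, and input_tile_offset then strides between successive tiles along the reduced dim. The same arithmetic in Python, with hypothetical shapes chosen for illustration:

def get_read_tile_id(tile_id, dim, CHtWt, HtWt):
    # dim == 0 (batch): output tile i starts at input tile i of batch 0.
    # dim == 1 (channel): jump to the owning batch (CHtWt tiles per batch),
    # keeping the offset within one channel's HtWt tiles.
    return tile_id if dim == 0 else (tile_id // HtWt) * CHtWt + (tile_id % HtWt)

# Example: N=2, C=3, Ht=Wt=1, so HtWt=1 and CHtWt=3. Reducing over channels
# (dim=1), output tile 1 (the second batch) starts at input tile 1*3 + 0 = 3;
# the inner loop then advances by the per-channel stride to tiles 4 and 5.
assert get_read_tile_id(1, 1, 3, 1) == 3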