#4741: Add sum op to tt_dnn
dongjin-na committed Jan 22, 2024
1 parent 8958ba3 commit 52add8b
Showing 31 changed files with 2,423 additions and 204 deletions.
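In short: the ad-hoc MorehSum op previously embedded in moreh_matmul_backward becomes a standalone moreh_sum op (with moreh_sum_h/moreh_sum_w/moreh_sum_nc program factories and a moreh_sum_backward op), exposed as ttl.operations.primary.moreh_sum. A minimal usage sketch distilled from the new tests; the sum_keepdim wrapper is hypothetical, and device is assumed to be opened elsewhere (e.g. by the pytest fixture):

import torch
import tt_lib as ttl

def sum_keepdim(torch_input, dims, device):
    # hypothetical wrapper: replicates torch.sum(torch_input, dims, True) on device
    output_shape = list(torch_input.shape)
    for d in dims:
        output_shape[d] = 1
    npu_dtype = ttl.tensor.DataType.BFLOAT16
    npu_layout = ttl.tensor.Layout.TILE
    tt_input = ttl.tensor.Tensor(torch_input, npu_dtype).pad_to_tile(float("nan")).to(npu_layout).to(device)
    tt_output = (
        ttl.tensor.Tensor(torch.zeros(output_shape, dtype=torch.bfloat16), npu_dtype)
        .pad_to_tile(float("nan"))
        .to(npu_layout)
        .to(device)
    )
    return (
        ttl.operations.primary.moreh_sum(tt_input, tt_output, dims=dims)
        .cpu()
        .to(ttl.tensor.Layout.ROW_MAJOR)
        .unpad_from_tile(output_shape)
        .to_torch()
    )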
@@ -7,8 +7,7 @@
from loguru import logger

import tt_lib as ttl
-from models.utility_functions import skip_for_wormhole_b0
-from models.utility_functions import comp_allclose_and_pcc
+from models.utility_functions import comp_allclose_and_pcc, skip_for_wormhole_b0


def get_tensors(input_shape, other_shape, output_shape, require_input_grad, require_other_grad, is_1d, device):
@@ -232,6 +231,7 @@ def test_moreh_matmul_backward(params, input_b1, input_b2, other_b1, other_b2, r
assert passing


+@skip_for_wormhole_b0()
@pytest.mark.parametrize(
"params",
(
@@ -275,6 +275,7 @@ def test_moreh_matmul(params, device):
assert passing


+@skip_for_wormhole_b0()
@pytest.mark.parametrize(
"params",
(
174 changes: 174 additions & 0 deletions tests/tt_eager/python_api_testing/unit_testing/test_moreh_sum.py
@@ -0,0 +1,174 @@
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import pytest
import torch
from loguru import logger

import tt_lib as ttl
from models.utility_functions import comp_allclose_and_pcc, skip_for_wormhole_b0

TILE_HEIGHT = 32
TILE_WIDTH = 32


def get_tensors(input_shape, output_shape, device):
    torch.manual_seed(2023)
    npu_dtype = ttl.tensor.DataType.BFLOAT16
    cpu_dtype = torch.bfloat16
    npu_layout = ttl.tensor.Layout.TILE

    torch_input = torch.randint(-2, 3, input_shape, dtype=cpu_dtype, requires_grad=True)
    torch_output = torch.randint(-2, 3, output_shape, dtype=cpu_dtype)

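    # pad to tile boundaries with NaN so stale padded values stand out; unpad_from_tile drops them on readback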
    tt_input = ttl.tensor.Tensor(torch_input, npu_dtype).pad_to_tile(float("nan")).to(npu_layout).to(device)
    tt_output = ttl.tensor.Tensor(torch_output, npu_dtype).pad_to_tile(float("nan")).to(npu_layout).to(device)

    return tt_input, tt_output, torch_input


def get_backward_tensors(output_grad_shape, input_grad_shape, device):
    torch.manual_seed(2023)
    npu_dtype = ttl.tensor.DataType.BFLOAT16
    cpu_dtype = torch.bfloat16
    npu_layout = ttl.tensor.Layout.TILE

    torch_output_grad = torch.randint(-2, 3, output_grad_shape, dtype=cpu_dtype, requires_grad=True)
    torch_input_grad = torch.randint(-2, 3, input_grad_shape, dtype=cpu_dtype)

    tt_output_grad = ttl.tensor.Tensor(torch_output_grad, npu_dtype).pad_to_tile(float("nan")).to(npu_layout).to(device)
    tt_input_grad = ttl.tensor.Tensor(torch_input_grad, npu_dtype).pad_to_tile(float("nan")).to(npu_layout).to(device)

    return tt_output_grad, tt_input_grad, torch_output_grad


# Dongjin: skipped on WH_B0 due to a known issue with sum reduction along the w-dim.
@skip_for_wormhole_b0()
@pytest.mark.parametrize(
    "input_shape",
    (
        ([1, 1, TILE_HEIGHT - 1, TILE_WIDTH - 1]),
        ([4, 4, TILE_HEIGHT * 9 - 1, TILE_WIDTH * 12 - 1]),
        ([4, 4, TILE_HEIGHT * 12 - 1, TILE_WIDTH * 9 - 1]),
        ([8, 8, TILE_HEIGHT * 4 - 1, TILE_WIDTH * 4 - 1]),
    ),
    ids=[
        "1, 1, TILE_HEIGHT - 1, TILE_WIDTH - 1",
        "4, 4, TILE_HEIGHT * 9 - 1, TILE_WIDTH * 12 - 1",
        "4, 4, TILE_HEIGHT * 12 - 1, TILE_WIDTH * 9 - 1",
        "8, 8, TILE_HEIGHT * 4 - 1, TILE_WIDTH * 4 - 1",
    ],
)
@pytest.mark.parametrize(
    "dims",
    (
        [0],
        [0, 1],
        [0, 1, 2],
        [0, 1, 2, 3],
        [0, 1, 3],
        [0, 2, 3],
        [1],
        [1, 2],
        [1, 2, 3],
        [1, 3],
        [2],
        [2, 3],
        [3],
    ),
    ids=["0", "0,1", "0,1,2", "0,1,2,3", "0,1,3", "0,2,3", "1", "1,2", "1,2,3", "1,3", "2", "2,3", "3"],
)
def test_moreh_sum_dims(input_shape, dims, device):
    output_shape = input_shape.copy()

    for dim in dims:
        output_shape[dim] = 1
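    # e.g. input_shape [4, 4, TILE_HEIGHT * 9 - 1, TILE_WIDTH * 12 - 1] = [4, 4, 287, 383] with
    # dims [1, 3] gives output_shape [4, 1, 287, 1]; keepdim semantics keep reduced dims at size 1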

    (tt_input, tt_output, torch_input) = get_tensors(input_shape, output_shape, device)

    torch_output = torch.sum(torch_input, dims, True)

    cpu_layout = ttl.tensor.Layout.ROW_MAJOR
    tt_output_cpu = (
        ttl.operations.primary.moreh_sum(tt_input, tt_output, dims=dims)
        .cpu()
        .to(cpu_layout)
        .unpad_from_tile(output_shape)
        .to_torch()
    )

    # test for equivalence
    # TODO(Dongjin): revisit rtol once fp32_dest_acc_en is enabled
    rtol = atol = 0.12
    passing, output_pcc = comp_allclose_and_pcc(torch_output, tt_output_cpu, pcc=0.999, rtol=rtol, atol=atol)

    logger.info(f"Out passing={passing}")
    logger.info(f"Output pcc={output_pcc}")

    assert passing


@pytest.mark.parametrize(
    "input_shape",
    (
        ([1, 1, TILE_HEIGHT - 1, TILE_WIDTH - 1]),
        ([4, 4, TILE_HEIGHT * 12 - 1, TILE_WIDTH * 30 - 1]),
        ([4, 4, TILE_HEIGHT * 30 - 1, TILE_WIDTH * 12 - 1]),
        ([8, 8, TILE_HEIGHT * 20 - 1, TILE_WIDTH * 20 - 1]),
    ),
    ids=[
        "1, 1, TILE_HEIGHT - 1, TILE_WIDTH - 1",
        "4, 4, TILE_HEIGHT * 12 - 1, TILE_WIDTH * 30 - 1",
        "4, 4, TILE_HEIGHT * 30 - 1, TILE_WIDTH * 12 - 1",
        "8, 8, TILE_HEIGHT * 20 - 1, TILE_WIDTH * 20 - 1",
    ],
)
@pytest.mark.parametrize(
    "dims",
    (
        [0],
        [0, 1],
        [0, 1, 2],
        [0, 1, 2, 3],
        [0, 1, 3],
        [0, 2, 3],
        [1],
        [1, 2],
        [1, 2, 3],
        [1, 3],
        [2],
        [2, 3],
        [3],
    ),
    ids=["0", "0,1", "0,1,2", "0,1,2,3", "0,1,3", "0,2,3", "1", "1,2", "1,2,3", "1,3", "2", "2,3", "3"],
)
def test_moreh_sum_backward(input_shape, dims, device):
    output_shape = input_shape.copy()

    for dim in dims:
        output_shape[dim] = 1

    (_, _, torch_input) = get_tensors(input_shape, output_shape, device)
    (tt_output_grad, tt_input_grad, torch_output_grad) = get_backward_tensors(output_shape, input_shape, device)

    torch_output = torch.sum(torch_input, dims, True)
    torch_output.backward(torch_output_grad)

    cpu_layout = ttl.tensor.Layout.ROW_MAJOR
    tt_input_grad_cpu = (
        ttl.operations.primary.moreh_sum_backward(tt_output_grad, tt_input_grad)
        .cpu()
        .to(cpu_layout)
        .unpad_from_tile(input_shape)
        .to_torch()
    )

    # test for equivalence
    rtol = atol = 0.1
    passing, output_pcc = comp_allclose_and_pcc(torch_input.grad, tt_input_grad_cpu, pcc=0.999, rtol=rtol, atol=atol)

    logger.info(f"Out passing={passing}")
    logger.info(f"Output pcc={output_pcc}")

    assert passing
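
As a sanity reference for the backward path: the gradient of a keepdim sum is simply output_grad broadcast back to the input shape. A PyTorch-only check of that semantics (illustrative; no tt_lib involved):

import torch

x = torch.randn(2, 2, 32, 32, requires_grad=True)
y = torch.sum(x, [1, 3], True)  # shape [2, 1, 32, 1]
g = torch.randn_like(y)
y.backward(g)
# every input element receives the grad of the output cell it was summed into
assert torch.equal(x.grad, g.expand(x.shape))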
7 changes: 6 additions & 1 deletion tt_eager/tt_dnn/module.mk
@@ -84,10 +84,15 @@ TT_DNN_SRCS = \
tt_eager/tt_dnn/op_library/moreh_softmax_backward/softmax_backward_w_large/softmax_backward_w_large.cpp \
tt_eager/tt_dnn/op_library/moreh_softmax_backward/softmax_backward_h_large/softmax_backward_h_large.cpp \
tt_eager/tt_dnn/op_library/moreh_softmax_backward/softmax_backward_c_large/softmax_backward_c_large.cpp \
+tt_eager/tt_dnn/op_library/moreh_sum/moreh_sum_h/moreh_sum_h.cpp \
+tt_eager/tt_dnn/op_library/moreh_sum/moreh_sum_w/moreh_sum_w.cpp \
+tt_eager/tt_dnn/op_library/moreh_sum/moreh_sum_nc/moreh_sum_nc.cpp \
+tt_eager/tt_dnn/op_library/moreh_sum/moreh_sum_op.cpp \
+tt_eager/tt_dnn/op_library/moreh_sum_backward/moreh_sum_backward.cpp \
+tt_eager/tt_dnn/op_library/moreh_sum_backward/moreh_sum_backward_op.cpp \
tt_eager/tt_dnn/op_library/layernorm/layernorm_op.cpp \
tt_eager/tt_dnn/op_library/moreh_matmul/multi_core/moreh_matmul_op_multi_core.cpp \
tt_eager/tt_dnn/op_library/moreh_matmul/moreh_matmul_op.cpp \
-tt_eager/tt_dnn/op_library/moreh_matmul_backward/sum/moreh_sum_multi_core.cpp \
tt_eager/tt_dnn/op_library/moreh_matmul_backward/moreh_matmul_backward_op.cpp \
tt_eager/tt_dnn/op_library/moreh_dot/single_core/moreh_dot_op_single_core.cpp \
tt_eager/tt_dnn/op_library/moreh_dot/moreh_dot_op.cpp \
20 changes: 14 additions & 6 deletions tt_eager/tt_dnn/op_library/moreh_matmul/moreh_matmul_op.cpp
@@ -99,12 +99,20 @@ inline void moreh_matmul_validate(

inline Shape compute_output_shape(
const Shape& input_shape, const Shape& other_shape, bool transpose_input, bool transpose_other) {
-    Shape output_shape{
-        std::max(input_shape[0], other_shape[0]),
-        std::max(input_shape[1], other_shape[1]),
-        (transpose_input) ? (input_shape[3]) : (input_shape[2]),
-        (transpose_other) ? (other_shape[2]) : (other_shape[3])};
-    return output_shape;
+    const auto& input_shape_wo_padding = input_shape.without_padding();
+    const auto& other_shape_wo_padding = other_shape.without_padding();
+
+    auto h = (transpose_input) ? (input_shape[3]) : (input_shape[2]);
+    auto w = (transpose_other) ? (other_shape[2]) : (other_shape[3]);
+    auto h_wo_padding = (transpose_input) ? (input_shape_wo_padding[3]) : (input_shape_wo_padding[2]);
+    auto w_wo_padding = (transpose_other) ? (other_shape_wo_padding[2]) : (other_shape_wo_padding[3]);
+
+    Shape output_shape{std::max(input_shape[0], other_shape[0]), std::max(input_shape[1], other_shape[1]), h, w};
+    auto padding = output_shape.padding();
+    padding[2] = Padding::PadDimension{0, h - h_wo_padding};
+    padding[3] = Padding::PadDimension{0, w - w_wo_padding};
+
+    return {Shape(output_shape, padding)};
}

inline Tensor create_output_tensor(
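The reworked compute_output_shape propagates the logical (unpadded) H and W into the output's padding metadata instead of returning only padded extents. A small Python rendering of the same arithmetic, using hypothetical shapes (63x31 and 31x63 logical data padded to 32-aligned tiles):

# padded shapes and their unpadded (logical) counterparts
input_shape, input_wo_padding = [1, 1, 64, 32], [1, 1, 63, 31]
other_shape, other_wo_padding = [1, 1, 32, 64], [1, 1, 31, 63]
transpose_input = transpose_other = False

h = input_shape[3] if transpose_input else input_shape[2]                # 64
w = other_shape[2] if transpose_other else other_shape[3]                # 64
h_wo = input_wo_padding[3] if transpose_input else input_wo_padding[2]   # 63
w_wo = other_wo_padding[2] if transpose_other else other_wo_padding[3]   # 63

output_shape = [max(input_shape[0], other_shape[0]), max(input_shape[1], other_shape[1]), h, w]
# back-padding recorded per dim as PadDimension{0, h - h_wo} and {0, w - w_wo}, i.e. (0, 1) here
assert (h - h_wo, w - w_wo) == (1, 1)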

This file was deleted.

@@ -3,6 +3,7 @@
// SPDX-License-Identifier: Apache-2.0

#include "tt_dnn/op_library/moreh_matmul_backward/moreh_matmul_backward_op.hpp"
#include "tt_dnn/op_library/moreh_sum/moreh_sum_op.hpp"

#include "tt_dnn/op_library/moreh_dot_backward/moreh_dot_backward_op.hpp"
#include "tt_dnn/op_library/moreh_matmul/moreh_matmul_op.hpp"
@@ -22,39 +23,6 @@ inline bool is_dot_backward(const Tensor& output_grad, const Tensor& input, cons
return is_scalar(output_grad) && is_1d_tensor(input) && is_1d_tensor(other) && is_same_shape(input, other);
}

-////////////////////////////////////////////////////////////////////////////
-// MorehSum
-////////////////////////////////////////////////////////////////////////////
-void MorehSum::validate(const std::vector<Tensor>& inputs) const {
-    const auto& src = inputs.at(0);
-    const auto& dst = inputs.at(1);
-    const auto& src_shape = src.shape();
-    const auto& dst_shape = dst.shape();
-
-    TT_ASSERT(src_shape[2] == dst_shape[2] && src_shape[3] == dst_shape[3]);
-    TT_ASSERT(src_shape[0] >= dst_shape[0]);
-    TT_ASSERT(src_shape[1] >= dst_shape[1]);
-}
-
-std::vector<Tensor> MorehSum::create_output_tensors(const std::vector<Tensor>& inputs) const {
-    // Inplace
-    return {};
-}
-
-std::vector<Shape> MorehSum::compute_output_shapes(const std::vector<Tensor>& inputs) const {
-    // Inplace
-    return {};
-}
-
-operation::ProgramWithCallbacks MorehSum::create_program(
-    const std::vector<Tensor>& inputs, std::vector<Tensor>& outputs) const {
-    const auto& src = inputs.at(0);
-    const auto& dst = inputs.at(1);
-    return moreh_sum_multi_core(src, dst);
-}
-
-stl::reflection::Attributes MorehSum::attributes() const { return {}; }
-
////////////////////////////////////////////////////////////////////////////
// moreh_matmul_backward
////////////////////////////////////////////////////////////////////////////
@@ -68,6 +36,16 @@ stl::reflection::Attributes MorehSum::attributes() const { return {}; }
std::vector<std::variant<Tensor, char*>> outputs;
outputs.reserve(2);

+    auto find_reduce_dim = [](const Shape& shape, const Shape& shape2) -> std::vector<int64_t> {
+        std::vector<int64_t> dims;
+        for (int i = 0; i < shape.rank() - 1; ++i) {
+            if (shape[i] != shape2[i]) {
+                dims.push_back(i);
+            }
+        }
+        return dims;
+    };

if (input_grad) {
const auto& input_grad_tensor = input_grad->get();
if (is_same_batch_shape(output_grad, input_grad_tensor)) {
Expand All @@ -78,7 +56,8 @@ stl::reflection::Attributes MorehSum::attributes() const { return {}; }
const auto& input_shape = input.shape().without_padding();
const auto& temp_input_grad =
moreh_matmul(output_grad, other, std::nullopt, false, true, output_mem_config);
-            operation::run(MorehSum{}, {temp_input_grad, input_grad_tensor});
+            auto reduce_dims = find_reduce_dim(temp_input_grad.shape(), input_grad_tensor.shape());
+            moreh_sum(temp_input_grad, input_grad_tensor, reduce_dims);
}
outputs.push_back(input_grad_tensor);
} else {
@@ -92,7 +71,8 @@ stl::reflection::Attributes MorehSum::attributes() const { return {}; }
} else {
const auto& temp_other_grad =
moreh_matmul(input, output_grad, std::nullopt, true, false, output_mem_config);
-            operation::run(MorehSum{}, {temp_other_grad, other_grad_tensor});
+            auto reduce_dims = find_reduce_dim(temp_other_grad.shape(), other_grad_tensor.shape());
+            moreh_sum(temp_other_grad, other_grad_tensor, reduce_dims);
}
outputs.push_back(other_grad_tensor);
} else {
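When the matmul result is broadcast over batch dims relative to the grad tensor, the replacement code reduces it with the new moreh_sum over exactly those dims; find_reduce_dim compares every dim except the last. A minimal Python rendering of the helper (for illustration only):

def find_reduce_dim(shape, shape2):
    # batch dims where the matmul result is larger than the grad tensor
    return [i for i in range(len(shape) - 1) if shape[i] != shape2[i]]

# e.g. a [4, 4, 64, 64] temporary grad vs a [1, 4, 64, 64] input grad: reduce over dim 0
assert find_reduce_dim([4, 4, 64, 64], [1, 4, 64, 64]) == [0]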