Skip to content

Commit

Permalink
#3900: Add prod support for batch and channels
Browse files Browse the repository at this point in the history
  • Loading branch information
Muthu authored and ruthreshx committed Jan 31, 2024
1 parent d43c1a5 commit 9006951
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 4 deletions.
11 changes: 10 additions & 1 deletion tests/tt_eager/python_api_testing/unit_testing/test_prod.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,12 +53,21 @@ def get_tensors(input_shape, output_shape, device):
[
1,
],
[
2,
],
[
3,
],
),
ids=["0", "1"],
ids=["0", "1", "2", "3"],
)
def test_moreh_prod_dims(input_shape, dims, device):
output_shape = input_shape.copy()

if dims[0] in [2, 3]:
pytest.skip(f"Dim {dims[0]} not supported at this time.")

for dim in dims:
output_shape[dim] = 1

Expand Down
6 changes: 3 additions & 3 deletions tt_eager/tt_dnn/op_library/prod/prod_nc_op.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ void Prod::validate(const std::vector<Tensor>& inputs) const {
const auto& output = inputs.at(1);

auto input_shape = input.shape();
TT_ASSERT((input_shape.rank()), "rank should be 4");
const auto& output_shape = output.shape();
auto input_shape_wo_padding = input.shape().without_padding();
const auto& output_shape_wo_padding = output.shape().without_padding();
Expand All @@ -32,8 +33,8 @@ void Prod::validate(const std::vector<Tensor>& inputs) const {
}

for (int i = 0; i < input_shape.rank(); ++i) {
TT_ASSERT(input_shape[i] == output_shape[i]);
TT_ASSERT(input_shape_wo_padding[i] == output_shape_wo_padding[i]);
TT_FATAL(input_shape[i] == output_shape[i]);
TT_FATAL(input_shape_wo_padding[i] == output_shape_wo_padding[i]);
}
}

Expand All @@ -49,7 +50,6 @@ std::vector<Shape> Prod::compute_output_shapes(const std::vector<Tensor>& inputs

operation::ProgramWithCallbacks Prod::create_program(
const std::vector<Tensor>& inputs, std::vector<Tensor>& outputs) const {
TT_ASSERT((dim >= 0 && dim <= 3), "dim should be 0 - 3");
auto& input = inputs.at(0);
auto& output = inputs.at(1);

Expand Down

0 comments on commit 9006951

Please sign in to comment.