From 9006951fc73a0c40a79c4c63873285b63763f218 Mon Sep 17 00:00:00 2001 From: Muthu Date: Thu, 25 Jan 2024 23:36:58 +0000 Subject: [PATCH] #3900: Add prod support for batch and channels --- .../python_api_testing/unit_testing/test_prod.py | 11 ++++++++++- tt_eager/tt_dnn/op_library/prod/prod_nc_op.cpp | 6 +++--- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/tests/tt_eager/python_api_testing/unit_testing/test_prod.py b/tests/tt_eager/python_api_testing/unit_testing/test_prod.py index 4be0522e832..bdd368ae0e5 100644 --- a/tests/tt_eager/python_api_testing/unit_testing/test_prod.py +++ b/tests/tt_eager/python_api_testing/unit_testing/test_prod.py @@ -53,12 +53,21 @@ def get_tensors(input_shape, output_shape, device): [ 1, ], + [ + 2, + ], + [ + 3, + ], ), - ids=["0", "1"], + ids=["0", "1", "2", "3"], ) def test_moreh_prod_dims(input_shape, dims, device): output_shape = input_shape.copy() + if dims[0] in [2, 3]: + pytest.skip(f"Dim {dims[0]} not supported at this time.") + for dim in dims: output_shape[dim] = 1 diff --git a/tt_eager/tt_dnn/op_library/prod/prod_nc_op.cpp b/tt_eager/tt_dnn/op_library/prod/prod_nc_op.cpp index 8861313934f..a80c06b123f 100644 --- a/tt_eager/tt_dnn/op_library/prod/prod_nc_op.cpp +++ b/tt_eager/tt_dnn/op_library/prod/prod_nc_op.cpp @@ -22,6 +22,7 @@ void Prod::validate(const std::vector& inputs) const { const auto& output = inputs.at(1); auto input_shape = input.shape(); + TT_ASSERT((input_shape.rank()), "rank should be 4"); const auto& output_shape = output.shape(); auto input_shape_wo_padding = input.shape().without_padding(); const auto& output_shape_wo_padding = output.shape().without_padding(); @@ -32,8 +33,8 @@ void Prod::validate(const std::vector& inputs) const { } for (int i = 0; i < input_shape.rank(); ++i) { - TT_ASSERT(input_shape[i] == output_shape[i]); - TT_ASSERT(input_shape_wo_padding[i] == output_shape_wo_padding[i]); + TT_FATAL(input_shape[i] == output_shape[i]); + TT_FATAL(input_shape_wo_padding[i] == output_shape_wo_padding[i]); } } @@ -49,7 +50,6 @@ std::vector Prod::compute_output_shapes(const std::vector& inputs operation::ProgramWithCallbacks Prod::create_program( const std::vector& inputs, std::vector& outputs) const { - TT_ASSERT((dim >= 0 && dim <= 3), "dim should be 0 - 3"); auto& input = inputs.at(0); auto& output = inputs.at(1);