diff --git a/tests/tt_eager/python_api_testing/sweep_tests/pytests/tt_dnn/test_argmax.py b/tests/tt_eager/python_api_testing/sweep_tests/pytests/tt_dnn/test_argmax.py index 37561ad5009..d2ca0d04e40 100644 --- a/tests/tt_eager/python_api_testing/sweep_tests/pytests/tt_dnn/test_argmax.py +++ b/tests/tt_eager/python_api_testing/sweep_tests/pytests/tt_dnn/test_argmax.py @@ -7,10 +7,8 @@ import tt_lib from loguru import logger from tests.tt_eager.python_api_testing.sweep_tests import comparison_funcs -from models.utility_functions import skip_for_wormhole_b0 -@skip_for_wormhole_b0("skip for WHB0 until @rtawfik/reduce_max_w_whb0_debug@ branch merge to main") @pytest.mark.parametrize( "input_shapes", ( diff --git a/tests/tt_eager/python_api_testing/sweep_tests/pytests/tt_dnn/test_argmin.py b/tests/tt_eager/python_api_testing/sweep_tests/pytests/tt_dnn/test_argmin.py index a51f2c968c9..0719e1434fc 100644 --- a/tests/tt_eager/python_api_testing/sweep_tests/pytests/tt_dnn/test_argmin.py +++ b/tests/tt_eager/python_api_testing/sweep_tests/pytests/tt_dnn/test_argmin.py @@ -7,10 +7,8 @@ import tt_lib from loguru import logger from tests.tt_eager.python_api_testing.sweep_tests import comparison_funcs -from models.utility_functions import skip_for_wormhole_b0 -@skip_for_wormhole_b0("skip for WHB0 until @rtawfik/reduce_max_w_whb0_debug@ branch merge to main") @pytest.mark.parametrize( "input_shapes", ( diff --git a/tt_eager/tt_dnn/op_library/composite/composite_ops.cpp b/tt_eager/tt_dnn/op_library/composite/composite_ops.cpp index 1f1c8300f9b..aa05bb961a5 100644 --- a/tt_eager/tt_dnn/op_library/composite/composite_ops.cpp +++ b/tt_eager/tt_dnn/op_library/composite/composite_ops.cpp @@ -1313,41 +1313,53 @@ Tensor repeat(const Tensor& input_a, const Shape& shape_b, const MemoryConfig& o // Argmax returns the index of maximum element in the tensor Tensor _argmax(const Tensor& input_a, int64_t _dim, const MemoryConfig& output_mem_config) { - uint32_t dim = input_a.shape().get_normalized_index(_dim); auto& input_shape = input_a.shape(); + TT_FATAL(input_shape.rank() == 4,"supported for rank-4 tensors at this time"); + + uint32_t dim = input_shape.get_normalized_index(_dim); int size = input_a.volume(); - TT_ASSERT((input_shape[0] == 1 && input_shape[1] == 1), "Unsupported shapes, supported shapes [1, 1, N, M]"); - if (dim == 3 ) + + TT_FATAL((input_shape[0] == 1 && input_shape[1] == 1), "Unsupported shapes, supported shapes [1, 1, N, M]"); + if (dim == (input_shape.rank() - 1)) { Tensor tindex = tt::numpy::index_width(input_shape, DataType::BFLOAT16); Tensor max_val = reduce(input_a, ReduceOpMath::MAX, ReduceOpDim::W); Tensor max_tensor = zeros_like(input_a, output_mem_config); max_tensor = bcast(max_tensor, max_val, BcastOpMath::ADD, BcastOpDim::W, output_mem_config); + max_val.deallocate(); Tensor cmp_results = eq(input_a, max_tensor, std::nullopt, output_mem_config); + max_tensor.deallocate(); Tensor max_indices = mul(cmp_results ,tindex, std::nullopt, output_mem_config); + cmp_results.deallocate(); Tensor midx = full_like(max_indices, size); - Tensor result = where(eqz(max_indices), midx, max_indices); + Tensor result = where(eqz(max_indices), midx, max_indices, output_mem_config); + max_indices.deallocate(); result = reduce(result, ReduceOpMath::MIN, ReduceOpDim::W); Tensor res_index = zeros_like(result, output_mem_config); - result = where(eq(result, full_like(result, size)), res_index, result); + result = where(eq(result, full_like(result, size)), res_index, result, output_mem_config); res_index = bcast(res_index, result, BcastOpMath::ADD, BcastOpDim::W, output_mem_config); + result.deallocate(); std::vector permute_dims = {0, 1, 3, 2}; Tensor transpose_res = permute(res_index,permute_dims,output_mem_config); return transpose_res; } - else if (dim == 2 ) + else if (dim == (input_shape.rank() - 2)) { Tensor tindex = tt::numpy::index_height(input_shape, DataType::BFLOAT16); Tensor max_val = reduce(input_a, ReduceOpMath::MAX, ReduceOpDim::H); Tensor max_tensor = zeros_like(input_a, output_mem_config); max_tensor = bcast(max_tensor, max_val, BcastOpMath::ADD, BcastOpDim::H, output_mem_config); + max_val.deallocate(); Tensor cmp_results = eq(input_a, max_tensor, std::nullopt, output_mem_config); + max_tensor.deallocate(); Tensor max_indices = mul(cmp_results ,tindex, std::nullopt, output_mem_config); + cmp_results.deallocate(); Tensor midx = full_like(max_indices, size); - Tensor result = where(eqz(max_indices), midx, max_indices); + Tensor result = where(eqz(max_indices), midx, max_indices, output_mem_config); + max_indices.deallocate(); result = reduce(result, ReduceOpMath::MIN, ReduceOpDim::H); Tensor res_index = zeros_like(result, output_mem_config); - result = where(eq(result, full_like(result, size)), res_index, result); + result = where(eq(result, full_like(result, size)), res_index, result, output_mem_config); res_index = bcast(res_index, result, BcastOpMath::ADD, BcastOpDim::H, output_mem_config); return res_index; } @@ -1359,13 +1371,17 @@ Tensor _argmax(const Tensor& input_a, int64_t _dim, const MemoryConfig& output_m Tensor max_val = reduce(input_a, ReduceOpMath::MAX, ReduceOpDim::HW); Tensor max_tensor = zeros_like(input_a, output_mem_config); max_tensor = bcast(max_tensor, max_val, BcastOpMath::ADD, BcastOpDim::HW, output_mem_config); + max_val.deallocate(); Tensor cmp_results = eq(input_a, max_tensor, std::nullopt, output_mem_config); + max_tensor.deallocate(); Tensor max_indices = mul(cmp_results ,tindex, std::nullopt, output_mem_config); + cmp_results.deallocate(); Tensor result = full_like(max_indices, size); - result = where(eqz(max_indices), result, max_indices); + result = where(eqz(max_indices), result, max_indices, output_mem_config); + max_indices.deallocate(); result = reduce(result, ReduceOpMath::MIN, ReduceOpDim::HW); Tensor res_index = zeros_like(result, output_mem_config); - result = where(eq(result, full_like(result, size)), res_index, result); + result = where(eq(result, full_like(result, size)), res_index, result, output_mem_config); res_index = bcast(res_index, result, BcastOpMath::ADD, BcastOpDim::HW, output_mem_config); return res_index; } @@ -1379,39 +1395,51 @@ Tensor argmax(const Tensor& input_a, int64_t dim, const MemoryConfig& output_mem // Argmax returns the index of maximum element in the tensor Tensor _argmin(const Tensor& input_a, int64_t _dim, const MemoryConfig& output_mem_config) { - uint32_t dim = input_a.shape().get_normalized_index(_dim); auto& input_shape = input_a.shape(); + TT_FATAL(input_shape.rank() == 4,"supported for rank-4 tensors at this time"); + + uint32_t dim = input_shape.get_normalized_index(_dim); int size = input_a.volume(); - TT_ASSERT((input_shape[0] == 1 && input_shape[1] == 1), "Unsupported shapes, supported shapes [1, 1, N, M]"); - if (dim == 3 ) + TT_FATAL((input_shape[0] == 1 && input_shape[1] == 1), "Unsupported shapes, supported shapes [1, 1, N, M]"); + + if (dim == (input_shape.rank() - 1)) { Tensor tindex = tt::numpy::index_width(input_shape, DataType::BFLOAT16); Tensor min_val = reduce(input_a, ReduceOpMath::MIN, ReduceOpDim::W); Tensor min_tensor = zeros_like(input_a, output_mem_config); min_tensor = bcast(min_tensor, min_val, BcastOpMath::ADD, BcastOpDim::W, output_mem_config); + min_val.deallocate(); Tensor cmp_results = eq(input_a, min_tensor, std::nullopt, output_mem_config); + min_tensor.deallocate(); Tensor min_indices = mul(cmp_results ,tindex, std::nullopt, output_mem_config); + cmp_results.deallocate(); Tensor midx = full_like(min_indices, size); Tensor result = where(eqz(min_indices), midx, min_indices); + min_indices.deallocate(); result = reduce(result, ReduceOpMath::MIN, ReduceOpDim::W); Tensor res_index = zeros_like(result, output_mem_config); result = where(eq(result, full_like(result, size)), res_index, result); res_index = bcast(res_index, result, BcastOpMath::ADD, BcastOpDim::W, output_mem_config); + result.deallocate(); std::vector permute_dims = {0, 1, 3, 2}; Tensor transpose_res = permute(res_index,permute_dims,output_mem_config); return transpose_res; } - else if (dim == 2 ) + else if (dim == (input_shape.rank() - 2)) { Tensor tindex = tt::numpy::index_height(input_shape, DataType::BFLOAT16); Tensor min_val = reduce(input_a, ReduceOpMath::MIN, ReduceOpDim::H); Tensor min_tensor = zeros_like(input_a, output_mem_config); min_tensor = bcast(min_tensor, min_val, BcastOpMath::ADD, BcastOpDim::H, output_mem_config); + min_val.deallocate(); Tensor cmp_results = eq(input_a, min_tensor, std::nullopt, output_mem_config); + min_tensor.deallocate(); Tensor min_indices = mul(cmp_results ,tindex, std::nullopt, output_mem_config); + cmp_results.deallocate(); Tensor midx = full_like(min_indices, size); Tensor result = where(eqz(min_indices), midx, min_indices); + min_indices.deallocate(); result = reduce(result, ReduceOpMath::MIN, ReduceOpDim::H); Tensor res_index = zeros_like(result, output_mem_config); result = where(eq(result, full_like(result, size)), res_index, result); @@ -1426,10 +1454,14 @@ Tensor _argmin(const Tensor& input_a, int64_t _dim, const MemoryConfig& output_m Tensor min_val = reduce(input_a, ReduceOpMath::MIN, ReduceOpDim::HW); Tensor min_tensor = zeros_like(input_a, output_mem_config); min_tensor = bcast(min_tensor, min_val, BcastOpMath::ADD, BcastOpDim::HW, output_mem_config); + min_val.deallocate(); Tensor cmp_results = eq(input_a, min_tensor, std::nullopt, output_mem_config); + min_tensor.deallocate(); Tensor min_indices = mul(cmp_results ,tindex, std::nullopt, output_mem_config); + cmp_results.deallocate(); Tensor result = full_like(min_indices, size); result = where(eqz(min_indices), result, min_indices); + min_indices.deallocate(); result = reduce(result, ReduceOpMath::MIN, ReduceOpDim::HW); Tensor res_index = zeros_like(result, output_mem_config); result = where(eq(result, full_like(result, size)), res_index, result); diff --git a/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_composite_ops.cpp b/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_composite_ops.cpp index d8dbf98e804..6a70f93c721 100644 --- a/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_composite_ops.cpp +++ b/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_composite_ops.cpp @@ -218,16 +218,14 @@ namespace tt::tt_metal::detail{ detail::bind_unary_op_with_param( m_tensor, "argmax", &argmax, py::arg("dim"), - R"doc(Returns the indices of the maximum value of all elements in the ``input`` tensor.) - Currently supported dimension are H, W, and HW. By default it returns argmax across HW. Suported dimension [1, 1, H, W])doc", + R"doc(Returns the indices of the maximum value of all elements in the ``input`` tensor. Currently supported dimension are H, W, and HW. By default it returns argmax across HW. Suported shapes [1, 1, H, W])doc", R"doc("dim", "int", "")doc" ); detail::bind_unary_op_with_param( m_tensor, "argmin", &argmin, py::arg("dim"), - R"doc(Returns the indices of the minimum value of all elements in the ``input`` tensor.) - Currently supported dimension are H, W, and HW. By default it returns argmin across HW. Suported dimension [1, 1, H, W])doc", + R"doc(Returns the indices of the minimum value of all elements in the ``input`` tensor. Currently supported dimension are H, W, and HW. By default it returns argmin across HW. Suported shapes [1, 1, H, W])doc", R"doc("dim", "int", "")doc" );