
Commit

#4125: Update argmin and argmax support with assert condition and deallocate

assert -> fatal
umadevimcw authored and muthutt committed Jan 24, 2024
1 parent 3e27fb5 commit ae373f4
Showing 4 changed files with 48 additions and 22 deletions.
@@ -7,10 +7,8 @@
import tt_lib
from loguru import logger
from tests.tt_eager.python_api_testing.sweep_tests import comparison_funcs
-from models.utility_functions import skip_for_wormhole_b0


-@skip_for_wormhole_b0("skip for WHB0 until @rtawfik/reduce_max_w_whb0_debug@ branch merge to main")
@pytest.mark.parametrize(
"input_shapes",
(
@@ -7,10 +7,8 @@
import tt_lib
from loguru import logger
from tests.tt_eager.python_api_testing.sweep_tests import comparison_funcs
-from models.utility_functions import skip_for_wormhole_b0


-@skip_for_wormhole_b0("skip for WHB0 until @rtawfik/reduce_max_w_whb0_debug@ branch merge to main")
@pytest.mark.parametrize(
"input_shapes",
(
60 changes: 46 additions & 14 deletions tt_eager/tt_dnn/op_library/composite/composite_ops.cpp
@@ -1313,41 +1313,53 @@ Tensor repeat(const Tensor& input_a, const Shape& shape_b, const MemoryConfig& o
// Argmax returns the index of maximum element in the tensor
Tensor _argmax(const Tensor& input_a, int64_t _dim, const MemoryConfig& output_mem_config) {

-uint32_t dim = input_a.shape().get_normalized_index(_dim);
+auto& input_shape = input_a.shape();
+TT_FATAL(input_shape.rank() == 4,"supported for rank-4 tensors at this time");
+
+uint32_t dim = input_shape.get_normalized_index(_dim);
int size = input_a.volume();
-TT_ASSERT((input_shape[0] == 1 && input_shape[1] == 1), "Unsupported shapes, supported shapes [1, 1, N, M]");
-if (dim == 3 )
+
+TT_FATAL((input_shape[0] == 1 && input_shape[1] == 1), "Unsupported shapes, supported shapes [1, 1, N, M]");
+if (dim == (input_shape.rank() - 1))
{
Tensor tindex = tt::numpy::index_width<bfloat16>(input_shape, DataType::BFLOAT16);
Tensor max_val = reduce(input_a, ReduceOpMath::MAX, ReduceOpDim::W);
Tensor max_tensor = zeros_like(input_a, output_mem_config);
max_tensor = bcast(max_tensor, max_val, BcastOpMath::ADD, BcastOpDim::W, output_mem_config);
+max_val.deallocate();
Tensor cmp_results = eq(input_a, max_tensor, std::nullopt, output_mem_config);
+max_tensor.deallocate();
Tensor max_indices = mul(cmp_results ,tindex, std::nullopt, output_mem_config);
+cmp_results.deallocate();
Tensor midx = full_like(max_indices, size);
-Tensor result = where(eqz(max_indices), midx, max_indices);
+Tensor result = where(eqz(max_indices), midx, max_indices, output_mem_config);
+max_indices.deallocate();
result = reduce(result, ReduceOpMath::MIN, ReduceOpDim::W);
Tensor res_index = zeros_like(result, output_mem_config);
-result = where(eq(result, full_like(result, size)), res_index, result);
+result = where(eq(result, full_like(result, size)), res_index, result, output_mem_config);
res_index = bcast(res_index, result, BcastOpMath::ADD, BcastOpDim::W, output_mem_config);
result.deallocate();
std::vector<int64_t> permute_dims = {0, 1, 3, 2};
Tensor transpose_res = permute(res_index,permute_dims,output_mem_config);
return transpose_res;
}
-else if (dim == 2 )
+else if (dim == (input_shape.rank() - 2))
{
Tensor tindex = tt::numpy::index_height<bfloat16>(input_shape, DataType::BFLOAT16);
Tensor max_val = reduce(input_a, ReduceOpMath::MAX, ReduceOpDim::H);
Tensor max_tensor = zeros_like(input_a, output_mem_config);
max_tensor = bcast(max_tensor, max_val, BcastOpMath::ADD, BcastOpDim::H, output_mem_config);
+max_val.deallocate();
Tensor cmp_results = eq(input_a, max_tensor, std::nullopt, output_mem_config);
+max_tensor.deallocate();
Tensor max_indices = mul(cmp_results ,tindex, std::nullopt, output_mem_config);
+cmp_results.deallocate();
Tensor midx = full_like(max_indices, size);
-Tensor result = where(eqz(max_indices), midx, max_indices);
+Tensor result = where(eqz(max_indices), midx, max_indices, output_mem_config);
+max_indices.deallocate();
result = reduce(result, ReduceOpMath::MIN, ReduceOpDim::H);
Tensor res_index = zeros_like(result, output_mem_config);
-result = where(eq(result, full_like(result, size)), res_index, result);
+result = where(eq(result, full_like(result, size)), res_index, result, output_mem_config);
res_index = bcast(res_index, result, BcastOpMath::ADD, BcastOpDim::H, output_mem_config);
return res_index;
}
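
Both branches above implement the same trick: with no native argmax reduction available, the op builds an index tensor, masks it down to the positions that hold the max, swaps zeros for an out-of-range sentinel, and min-reduces to pick an index. Below is a minimal numpy sketch of the W-dim branch for a [1, 1, N, M] input; it assumes tt::numpy::index_width yields 0-based positions (the helper is not shown in this diff), and it is an illustration, not the tt_lib implementation.

import numpy as np

def argmax_last_dim(x):
    # x: array reduced along its last axis, mirroring the dim == rank-1 branch.
    size = x.size                                              # sentinel beyond any valid index
    tindex = np.broadcast_to(np.arange(x.shape[-1]), x.shape)  # index_width analogue: 0..M-1
    max_val = x.max(axis=-1, keepdims=True)                    # reduce(MAX, W) + bcast ADD
    cmp = (x == max_val).astype(x.dtype)                       # eq(input_a, max_tensor)
    max_indices = cmp * tindex                                 # mul(cmp_results, tindex)
    # A max sitting at index 0 also produces 0 here, so zeros become the sentinel first...
    result = np.where(max_indices == 0, size, max_indices)
    result = result.min(axis=-1)                               # reduce(MIN, W)
    # ...and any row still carrying the sentinel maps back to index 0.
    return np.where(result == size, 0, result)

The kernel's trailing bcast and permute only reshape the per-row result back into tile layout; numpy's axis reduction makes that step implicit.
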
@@ -1359,13 +1371,17 @@ Tensor _argmax(const Tensor& input_a, int64_t _dim, const MemoryConfig& output_m
Tensor max_val = reduce(input_a, ReduceOpMath::MAX, ReduceOpDim::HW);
Tensor max_tensor = zeros_like(input_a, output_mem_config);
max_tensor = bcast(max_tensor, max_val, BcastOpMath::ADD, BcastOpDim::HW, output_mem_config);
+max_val.deallocate();
Tensor cmp_results = eq(input_a, max_tensor, std::nullopt, output_mem_config);
+max_tensor.deallocate();
Tensor max_indices = mul(cmp_results ,tindex, std::nullopt, output_mem_config);
+cmp_results.deallocate();
Tensor result = full_like(max_indices, size);
-result = where(eqz(max_indices), result, max_indices);
+result = where(eqz(max_indices), result, max_indices, output_mem_config);
+max_indices.deallocate();
result = reduce(result, ReduceOpMath::MIN, ReduceOpDim::HW);
Tensor res_index = zeros_like(result, output_mem_config);
-result = where(eq(result, full_like(result, size)), res_index, result);
+result = where(eq(result, full_like(result, size)), res_index, result, output_mem_config);
res_index = bcast(res_index, result, BcastOpMath::ADD, BcastOpDim::HW, output_mem_config);
return res_index;
}
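
The fall-through HW path is the same mask-and-min-reduce applied over every element at once. A sketch under the same assumptions (the tindex construction for this branch sits in the collapsed lines above, so a flat 0-based enumeration is assumed):

import numpy as np

def argmax_all(x):
    # Mirrors the default HW branch: one global MAX, then index selection via MIN.
    size = x.size
    tindex = np.arange(size, dtype=float).reshape(x.shape)  # assumed flat enumeration
    cmp = (x == x.max()).astype(float)                      # reduce(MAX, HW), then eq
    max_indices = cmp * tindex
    result = np.where(max_indices == 0, size, max_indices).min()
    return 0 if result == size else int(result)             # sentinel means max at index 0
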
@@ -1379,39 +1395,51 @@ Tensor argmax(const Tensor& input_a, int64_t dim, const MemoryConfig& output_mem
// Argmax returns the index of maximum element in the tensor
Tensor _argmin(const Tensor& input_a, int64_t _dim, const MemoryConfig& output_mem_config) {

-uint32_t dim = input_a.shape().get_normalized_index(_dim);
+auto& input_shape = input_a.shape();
+TT_FATAL(input_shape.rank() == 4,"supported for rank-4 tensors at this time");
+
+uint32_t dim = input_shape.get_normalized_index(_dim);
int size = input_a.volume();
-TT_ASSERT((input_shape[0] == 1 && input_shape[1] == 1), "Unsupported shapes, supported shapes [1, 1, N, M]");
-
-if (dim == 3 )
+TT_FATAL((input_shape[0] == 1 && input_shape[1] == 1), "Unsupported shapes, supported shapes [1, 1, N, M]");
+
+if (dim == (input_shape.rank() - 1))
{
Tensor tindex = tt::numpy::index_width<bfloat16>(input_shape, DataType::BFLOAT16);
Tensor min_val = reduce(input_a, ReduceOpMath::MIN, ReduceOpDim::W);
Tensor min_tensor = zeros_like(input_a, output_mem_config);
min_tensor = bcast(min_tensor, min_val, BcastOpMath::ADD, BcastOpDim::W, output_mem_config);
+min_val.deallocate();
Tensor cmp_results = eq(input_a, min_tensor, std::nullopt, output_mem_config);
+min_tensor.deallocate();
Tensor min_indices = mul(cmp_results ,tindex, std::nullopt, output_mem_config);
+cmp_results.deallocate();
Tensor midx = full_like(min_indices, size);
Tensor result = where(eqz(min_indices), midx, min_indices);
+min_indices.deallocate();
result = reduce(result, ReduceOpMath::MIN, ReduceOpDim::W);
Tensor res_index = zeros_like(result, output_mem_config);
result = where(eq(result, full_like(result, size)), res_index, result);
res_index = bcast(res_index, result, BcastOpMath::ADD, BcastOpDim::W, output_mem_config);
+result.deallocate();
std::vector<int64_t> permute_dims = {0, 1, 3, 2};
Tensor transpose_res = permute(res_index,permute_dims,output_mem_config);
return transpose_res;
}
-else if (dim == 2 )
+else if (dim == (input_shape.rank() - 2))
{
Tensor tindex = tt::numpy::index_height<bfloat16>(input_shape, DataType::BFLOAT16);
Tensor min_val = reduce(input_a, ReduceOpMath::MIN, ReduceOpDim::H);
Tensor min_tensor = zeros_like(input_a, output_mem_config);
min_tensor = bcast(min_tensor, min_val, BcastOpMath::ADD, BcastOpDim::H, output_mem_config);
+min_val.deallocate();
Tensor cmp_results = eq(input_a, min_tensor, std::nullopt, output_mem_config);
+min_tensor.deallocate();
Tensor min_indices = mul(cmp_results ,tindex, std::nullopt, output_mem_config);
+cmp_results.deallocate();
Tensor midx = full_like(min_indices, size);
Tensor result = where(eqz(min_indices), midx, min_indices);
+min_indices.deallocate();
result = reduce(result, ReduceOpMath::MIN, ReduceOpDim::H);
Tensor res_index = zeros_like(result, output_mem_config);
result = where(eq(result, full_like(result, size)), res_index, result);
@@ -1426,10 +1454,14 @@ Tensor _argmin(const Tensor& input_a, int64_t _dim, const MemoryConfig& output_m
Tensor min_val = reduce(input_a, ReduceOpMath::MIN, ReduceOpDim::HW);
Tensor min_tensor = zeros_like(input_a, output_mem_config);
min_tensor = bcast(min_tensor, min_val, BcastOpMath::ADD, BcastOpDim::HW, output_mem_config);
+min_val.deallocate();
Tensor cmp_results = eq(input_a, min_tensor, std::nullopt, output_mem_config);
+min_tensor.deallocate();
Tensor min_indices = mul(cmp_results ,tindex, std::nullopt, output_mem_config);
+cmp_results.deallocate();
Tensor result = full_like(min_indices, size);
result = where(eqz(min_indices), result, min_indices);
+min_indices.deallocate();
result = reduce(result, ReduceOpMath::MIN, ReduceOpDim::HW);
Tensor res_index = zeros_like(result, output_mem_config);
result = where(eq(result, full_like(result, size)), res_index, result);
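_argmin is the exact mirror: only the value reduction flips from MAX to MIN, while the index-selection reduce stays MIN in both ops (it always picks the lowest surviving position). The deallocate() calls added throughout free each intermediate (min_val, min_tensor, cmp_results, the index products) as soon as its last consumer has run, trimming peak memory across the chain. A sketch of the mirrored W-dim branch, under the same 0-based-index assumption as above:

import numpy as np

def argmin_last_dim(x):
    # Only the first reduction differs from the argmax sketch (min instead of max);
    # the final index-picking reduction is MIN in both ops.
    size = x.size
    tindex = np.broadcast_to(np.arange(x.shape[-1]), x.shape)
    cmp = (x == x.min(axis=-1, keepdims=True)).astype(x.dtype)
    min_indices = cmp * tindex
    result = np.where(min_indices == 0, size, min_indices).min(axis=-1)
    return np.where(result == size, 0, result)
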
6 changes: 2 additions & 4 deletions tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_composite_ops.cpp
@@ -218,16 +218,14 @@ namespace tt::tt_metal::detail{
detail::bind_unary_op_with_param(
m_tensor, "argmax", &argmax,
py::arg("dim"),
R"doc(Returns the indices of the maximum value of all elements in the ``input`` tensor.)
Currently supported dimension are H, W, and HW. By default it returns argmax across HW. Suported dimension [1, 1, H, W])doc",
R"doc(Returns the indices of the maximum value of all elements in the ``input`` tensor. Currently supported dimension are H, W, and HW. By default it returns argmax across HW. Suported shapes [1, 1, H, W])doc",
R"doc("dim", "int", "")doc"
);

detail::bind_unary_op_with_param(
m_tensor, "argmin", &argmin,
py::arg("dim"),
R"doc(Returns the indices of the minimum value of all elements in the ``input`` tensor.)
Currently supported dimension are H, W, and HW. By default it returns argmin across HW. Suported dimension [1, 1, H, W])doc",
R"doc(Returns the indices of the minimum value of all elements in the ``input`` tensor. Currently supported dimension are H, W, and HW. By default it returns argmin across HW. Suported shapes [1, 1, H, W])doc",
R"doc("dim", "int", "")doc"
);

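A hypothetical call site for the two bindings; the module path, tensor name, and dim values below are assumptions inferred from m_tensor and the rank-4 C++ checks (dim 3 normalizes to W, dim 2 to H), not confirmed by this page:

import tt_lib

# input_tensor: a [1, 1, N, M] device tensor created elsewhere.
w_indices = tt_lib.tensor.argmax(input_tensor, dim=3)  # indices of the max along W
h_indices = tt_lib.tensor.argmin(input_tensor, dim=2)  # indices of the min along H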
