Skip to content

Commit

Permalink
#14730: support unequal rank inputs
Browse files Browse the repository at this point in the history
  • Loading branch information
KalaivaniMCW committed Nov 8, 2024
1 parent e601275 commit a3b1db7
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 0 deletions.
27 changes: 27 additions & 0 deletions tests/ttnn/unit_tests/operations/eltwise/test_add.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,33 @@
from tests.ttnn.utils_for_testing import assert_with_pcc


@pytest.mark.parametrize(
    "shapes",
    [
        [[1, 71, 7, 7], [7, 7]],
        [[920, 1, 256], [256]],
    ],
)
def test_unequal_ranks(device, shapes):
    """ttnn.add must broadcast operands of unequal rank (e.g. 4D + 2D).

    The lower-rank input is expected to be right-aligned against the
    higher-rank one, matching torch broadcasting semantics.
    """
    torch.manual_seed(0)

    torch_input_tensor_a = torch.rand(shapes[0], dtype=torch.bfloat16)
    torch_input_tensor_b = torch.rand(shapes[1], dtype=torch.bfloat16)
    # Golden reference computed via torch's native broadcasting.
    torch_output_tensor = torch_input_tensor_a + torch_input_tensor_b

    input_tensor_a = ttnn.from_torch(
        torch_input_tensor_a, layout=ttnn.TILE_LAYOUT, device=device, memory_config=ttnn.DRAM_MEMORY_CONFIG
    )
    input_tensor_b = ttnn.from_torch(
        torch_input_tensor_b, layout=ttnn.TILE_LAYOUT, device=device, memory_config=ttnn.DRAM_MEMORY_CONFIG
    )

    output_tensor = ttnn.add(input_tensor_a, input_tensor_b, memory_config=ttnn.DRAM_MEMORY_CONFIG)
    output_tensor = ttnn.to_torch(output_tensor)

    # Use the shared PCC helper (imported at the top of this module) instead of
    # a bare boolean assert, so a failure reports the measured correlation and
    # stays consistent with the rest of the test file.
    assert_with_pcc(torch_output_tensor, output_tensor, 0.99988)


@pytest.mark.parametrize(
"shapes",
[
Expand Down
32 changes: 32 additions & 0 deletions tests/ttnn/unit_tests/operations/eltwise/test_mul.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,38 @@
from torch.nn import functional as F


@pytest.mark.parametrize(
    "shapes",
    [
        [[4, 12, 64, 64], [12, 1, 1]],
        [[4, 16, 64, 64], [16, 1, 1]],
        [[64, 3, 64, 64], [3, 1, 1]],
        [[64, 4, 64, 64], [4, 1, 1]],
        [[16, 6, 64, 64], [6, 1, 1]],
        [[16, 8, 64, 64], [8, 1, 1]],
        [[1, 1], [1, 1, 32]],
    ],
)
def test_unequal_ranks(device, shapes):
    """Multiplying tensors of unequal rank should broadcast like torch.mul."""
    torch.manual_seed(0)

    shape_a, shape_b = shapes
    reference_a = torch.rand(shape_a, dtype=torch.bfloat16)
    reference_b = torch.rand(shape_b, dtype=torch.bfloat16)
    # Golden result from torch's native broadcasting.
    expected = reference_a * reference_b

    def to_device_tensor(host_tensor):
        # Tile layout in DRAM, same configuration as the eltwise tests above.
        return ttnn.from_torch(
            host_tensor, layout=ttnn.TILE_LAYOUT, device=device, memory_config=ttnn.DRAM_MEMORY_CONFIG
        )

    device_a = to_device_tensor(reference_a)
    device_b = to_device_tensor(reference_b)

    actual = ttnn.to_torch(ttnn.mul(device_a, device_b, memory_config=ttnn.DRAM_MEMORY_CONFIG))

    assert ttnn.pearson_correlation_coefficient(expected, actual) >= 0.99988


# fmt: off
@pytest.mark.parametrize("scalar", [3.0])
# fmt: on
Expand Down
25 changes: 25 additions & 0 deletions ttnn/cpp/ttnn/operations/eltwise/binary/binary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
#include "ttnn/device_operation.hpp"
#include "ttnn/operations/data_movement/repeat/repeat.hpp"
#include "ttnn/operations/eltwise/unary/unary.hpp"
#include "ttnn/operations/data_movement/unsqueeze/unsqueeze.hpp"
#include "ttnn/operations/data_movement/reshape_view/reshape.hpp"

namespace ttnn::operations::binary {

Expand Down Expand Up @@ -103,6 +105,29 @@ auto preprocess_inputs(const Tensor &input_tensor_a_arg, const Tensor &input_ten
Tensor input_tensor_a = input_tensor_a_arg;
Tensor input_tensor_b = input_tensor_b_arg;

auto rank_a = input_tensor_a.get_shape().rank();
auto rank_b = input_tensor_b.get_shape().rank();
int diff = std::abs((int)rank_a - (int)rank_b);

if(rank_a != rank_b){
if(rank_a > rank_b){
auto s_b = input_tensor_b.get_shape();
std::vector<int32_t> shape_vector(rank_a, 1);
for(int i=0; i < rank_b; ++i){
shape_vector[diff + i] = s_b[i];
}
input_tensor_b = ttnn::reshape(input_tensor_b, shape_vector);
}
if(rank_a < rank_b){
auto s_a = input_tensor_a.get_shape();
std::vector<int32_t> shape_vector(rank_b, 1);
for(int i=0; i < rank_a; ++i){
shape_vector[diff + i] = s_a[i];
}
input_tensor_a = ttnn::reshape(input_tensor_a, shape_vector);
}
}

// TODO: #7731 (Remove calls to repeat )
auto repeat_smaller = [](const auto &first, auto &second) {
const auto first_shape = first.get_shape();
Expand Down

0 comments on commit a3b1db7

Please sign in to comment.