Skip to content

Commit

Permalink
#8691: refactored reduction.py
Browse files Browse the repository at this point in the history
  • Loading branch information
arakhmati committed May 21, 2024
1 parent e1404ac commit 3e5adfd
Show file tree
Hide file tree
Showing 4 changed files with 176 additions and 357 deletions.
5 changes: 1 addition & 4 deletions tests/ttnn/unit_tests/operations/test_max.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,7 @@ def test_max_global(device, batch_size, h, w):
input_tensor = ttnn.from_torch(torch_input_tensor, layout=ttnn.TILE_LAYOUT, device=device)

output_tensor = ttnn.max(input_tensor)
output_tensor = ttnn.to_layout(output_tensor, ttnn.TILE_LAYOUT)
output_tensor = ttnn.from_device(output_tensor)

output_tensor = ttnn.to_torch(output_tensor)
output_tensor = output_tensor[0, 0, 0, 0]
output_tensor = output_tensor[0, 0, 0]

assert_with_pcc(torch_output_tensor, output_tensor)
2 changes: 1 addition & 1 deletion tests/ttnn/unit_tests/operations/test_min.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,6 @@ def test_min_global(device, batch_size, h, w):
output_tensor = ttnn.from_device(output_tensor)

output_tensor = ttnn.to_torch(output_tensor)
output_tensor = output_tensor[0, 0, 0, 0]
output_tensor = output_tensor[0, 0, 0]

assert_with_pcc(torch_output_tensor, output_tensor)
2 changes: 1 addition & 1 deletion tests/ttnn/unit_tests/operations/test_sum.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,6 @@ def test_sum_global(device, batch_size, h, w):
output_tensor = ttnn.from_device(output_tensor)

output_tensor = ttnn.to_torch(output_tensor)
output_tensor = output_tensor[0, 0, 0, 0]
output_tensor = output_tensor[0, 0, 0]

assert_with_pcc(torch_output_tensor, output_tensor)
Loading

0 comments on commit 3e5adfd

Please sign in to comment.