#14862: fp32 support in unary (#14899)
### Ticket
Link to GitHub issue #14862

### Problem description
For float32 inputs, unary ops set `preserve_fp32_precision` only when the op was `TYPECAST`; for every other unary op, fp32 data was converted to fp16b in DST and lost precision (GitHub issue #14862).

### What's changed
Enabled the `preserve_fp32_precision` flag for float32 inputs to unary ops on non-Grayskull architectures (previously it was set only for `TYPECAST`); see the sketch below.
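
For illustration, a minimal sketch of the behavior this enables, mirroring the new unit tests; the explicit device open/close here is an assumption for the sketch (the tests get `device` from a pytest fixture):

```python
import torch
import ttnn

# Assumes a single available device (the unit tests obtain `device` from a fixture).
device = ttnn.open_device(device_id=0)

x_torch = torch.tensor([[0.00001]], dtype=torch.float32)
x_tt = ttnn.from_torch(x_torch, dtype=ttnn.float32, layout=ttnn.TILE_LAYOUT, device=device)

# With a float32 input, the unary op now runs with preserve_fp32_precision set,
# so small magnitudes are not down-converted to fp16b on the way through DST.
y_tt = ttnn.neg(x_tt)

assert torch.allclose(-x_torch, ttnn.to_torch(y_tt), atol=1e-10, rtol=1e-5)

ttnn.close_device(device)
```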

### Checklist
- [x] Post commit CI passes
https://github.com/tenstorrent/tt-metal/actions/runs/11752673780
https://github.com/tenstorrent/tt-metal/actions/runs/11797734586
- [ ] Nightly fd
https://github.com/tenstorrent/tt-metal/actions/runs/11797739127
- [ ] Model perf
- [ ] Device perf
- [ ] Demo tests
- [ ] Blackhole Post commit (if applicable)
- [ ] Model regression CI testing passes (if applicable)
- [ ] Device performance regression CI testing passes (if applicable)
- [x] New/Existing tests provide coverage for changes
KalaivaniMCW authored Nov 12, 2024
1 parent a39f998 commit 0039938
Showing 2 changed files with 227 additions and 3 deletions.
224 changes: 224 additions & 0 deletions tests/ttnn/unit_tests/operations/eltwise/test_unary_fp32.py
@@ -0,0 +1,224 @@
# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import torch
import ttnn

import pytest
from models.utility_functions import skip_for_grayskull
from tests.ttnn.utils_for_testing import assert_with_pcc


@skip_for_grayskull("Unsupported dtype for Grayskull")
@pytest.mark.parametrize(
    "ttnn_function",
    [
        ttnn.neg,
    ],
)
def test_neg_fp32(device, ttnn_function):
    x_torch = torch.tensor([[0.00001]], dtype=torch.float32)
    y_torch = -x_torch

    x_tt = ttnn.from_torch(x_torch, dtype=ttnn.float32, layout=ttnn.TILE_LAYOUT, device=device)

    y_tt = ttnn_function(x_tt)

    tt_out = ttnn.to_torch(y_tt)
    status = torch.allclose(y_torch, tt_out, atol=1e-10, rtol=1e-5, equal_nan=False)
    assert status


@skip_for_grayskull("Unsupported dtype for Grayskull")
@pytest.mark.parametrize(
    "ttnn_function",
    [
        ttnn.sin,
    ],
)
def test_sin_fp32(device, ttnn_function):
    x_torch = torch.rand((64, 128), dtype=torch.float32)
    y_torch = torch.sin(x_torch)

    x_tt = ttnn.from_torch(x_torch, dtype=ttnn.float32, layout=ttnn.TILE_LAYOUT, device=device)

    y_tt = ttnn_function(x_tt)

    tt_out = ttnn.to_torch(y_tt)
    status = torch.allclose(y_torch, tt_out, atol=1e-10, rtol=1e-5, equal_nan=False)
    assert status


@skip_for_grayskull("Unsupported dtype for Grayskull")
@pytest.mark.parametrize(
    "ttnn_function",
    [
        ttnn.cos,
    ],
)
def test_cos_fp32(device, ttnn_function):
    x_torch = torch.rand((64, 128), dtype=torch.float32)
    y_torch = torch.cos(x_torch)

    x_tt = ttnn.from_torch(x_torch, dtype=ttnn.float32, layout=ttnn.TILE_LAYOUT, device=device)

    y_tt = ttnn_function(x_tt)

    tt_out = ttnn.to_torch(y_tt)
    status = torch.allclose(y_torch, tt_out, atol=1e-10, rtol=1e-5, equal_nan=False)
    assert status


@skip_for_grayskull("Unsupported dtype for Grayskull")
@pytest.mark.parametrize(
    "ttnn_function",
    [
        ttnn.tan,
    ],
)
def test_tan_fp32(device, ttnn_function):
    x_torch = torch.rand((64, 128), dtype=torch.float32)
    y_torch = torch.tan(x_torch)

    x_tt = ttnn.from_torch(x_torch, dtype=ttnn.float32, layout=ttnn.TILE_LAYOUT, device=device)

    y_tt = ttnn_function(x_tt)

    tt_out = ttnn.to_torch(y_tt)
    status = torch.allclose(y_torch, tt_out, atol=1e-10, rtol=1e-5, equal_nan=False)
    assert status


@skip_for_grayskull("Unsupported dtype for Grayskull")
@pytest.mark.parametrize(
    "ttnn_function",
    [
        ttnn.relu,
    ],
)
def test_relu_fp32(device, ttnn_function):
    x_torch = torch.rand((64, 128), dtype=torch.float32)
    y_torch = torch.relu(x_torch)

    x_tt = ttnn.from_torch(x_torch, dtype=ttnn.float32, layout=ttnn.TILE_LAYOUT, device=device)

    y_tt = ttnn_function(x_tt)

    tt_out = ttnn.to_torch(y_tt)
    status = torch.allclose(y_torch, tt_out, atol=1e-10, rtol=1e-5, equal_nan=False)
    assert status


def run_unary_test(device, h, w, ttnn_function, pcc=0.9999):
    torch.manual_seed(0)

    torch_input_tensor = torch.rand((h, w), dtype=torch.float32)
    golden_function = ttnn.get_golden_function(ttnn_function)
    torch_output_tensor = golden_function(torch_input_tensor, device=device)

    input_tensor = ttnn.from_torch(torch_input_tensor, dtype=ttnn.float32, layout=ttnn.TILE_LAYOUT, device=device)
    output_tensor = ttnn_function(input_tensor)
    output_tensor = ttnn.to_layout(output_tensor, ttnn.ROW_MAJOR_LAYOUT)
    output_tensor = ttnn.from_device(output_tensor)
    output_tensor = ttnn.to_torch(output_tensor)

    assert_with_pcc(torch_output_tensor, output_tensor, pcc)


@skip_for_grayskull("Unsupported dtype for Grayskull")
@pytest.mark.parametrize("h", [64])
@pytest.mark.parametrize("w", [128])
def test_exp(device, h, w):
    run_unary_test(device, h, w, ttnn.exp, pcc=0.9998)


@skip_for_grayskull("Unsupported dtype for Grayskull")
@pytest.mark.parametrize("h", [64])
@pytest.mark.parametrize("w", [128])
def test_tanh(device, h, w):
    run_unary_test(device, h, w, ttnn.tanh, pcc=0.993)


@skip_for_grayskull("Unsupported dtype for Grayskull")
@pytest.mark.parametrize("h", [64])
@pytest.mark.parametrize("w", [128])
def test_gelu(device, h, w):
    run_unary_test(device, h, w, ttnn.gelu, pcc=0.9996)


@skip_for_grayskull("Unsupported dtype for Grayskull")
@pytest.mark.parametrize("h", [64])
@pytest.mark.parametrize("w", [128])
def test_rsqrt(device, h, w):
    run_unary_test(device, h, w, ttnn.rsqrt)


@skip_for_grayskull("Unsupported dtype for Grayskull")
@pytest.mark.parametrize("h", [64])
@pytest.mark.parametrize("w", [128])
def test_silu(device, h, w):
    run_unary_test(device, h, w, ttnn.silu)


@skip_for_grayskull("Unsupported dtype for Grayskull")
@pytest.mark.parametrize("h", [64])
@pytest.mark.parametrize("w", [128])
def test_log(device, h, w):
    run_unary_test(device, h, w, ttnn.log)


@skip_for_grayskull("Unsupported dtype for Grayskull")
@pytest.mark.parametrize("h", [64])
@pytest.mark.parametrize("w", [128])
def test_asin(device, h, w):
    run_unary_test(device, h, w, ttnn.asin, pcc=0.998)


@skip_for_grayskull("Unsupported dtype for Grayskull")
@pytest.mark.parametrize("h", [64])
@pytest.mark.parametrize("w", [128])
def test_acos(device, h, w):
    run_unary_test(device, h, w, ttnn.acos, pcc=0.998)


@skip_for_grayskull("Unsupported dtype for Grayskull")
@pytest.mark.parametrize("h", [64])
@pytest.mark.parametrize("w", [128])
def test_atan(device, h, w):
    run_unary_test(device, h, w, ttnn.atan)


@skip_for_grayskull("Unsupported dtype for Grayskull")
@pytest.mark.parametrize("h", [64])
@pytest.mark.parametrize("w", [128])
def test_sinh(device, h, w):
    run_unary_test(device, h, w, ttnn.sinh)


@skip_for_grayskull("Unsupported dtype for Grayskull")
@pytest.mark.parametrize("h", [64])
@pytest.mark.parametrize("w", [128])
def test_asinh(device, h, w):
    run_unary_test(device, h, w, ttnn.asinh, pcc=0.9997)


@skip_for_grayskull("Unsupported dtype for Grayskull")
@pytest.mark.parametrize("h", [64])
@pytest.mark.parametrize("w", [128])
def test_cosh(device, h, w):
    run_unary_test(device, h, w, ttnn.cosh, pcc=0.999)


@skip_for_grayskull("Unsupported dtype for Grayskull")
@pytest.mark.parametrize("h", [64])
@pytest.mark.parametrize("w", [128])
def test_acosh(device, h, w):
    run_unary_test(device, h, w, ttnn.acosh)


@skip_for_grayskull("Unsupported dtype for Grayskull")
@pytest.mark.parametrize("h", [64])
@pytest.mark.parametrize("w", [128])
def test_atanh(device, h, w):
    run_unary_test(device, h, w, ttnn.atanh, pcc=0.997)
6 changes: 3 additions & 3 deletions ttnn/cpp/ttnn/operations/eltwise/unary/unary.cpp
@@ -23,14 +23,14 @@ inline Tensor unary_impl(
     const std::optional<MemoryConfig>& memory_config = std::nullopt,
     const std::optional<Tensor>& optional_output_tensor = std::nullopt) {
     DataType output_dtype = (op_chain[0].op_type == UnaryOpType::TYPECAST) ? static_cast<DataType>(op_chain[0].params[1]) : input_tensor.get_dtype();
-    bool preserve_fp32_precision = (op_chain[0].op_type == UnaryOpType::TYPECAST) and (input_tensor.get_dtype() == DataType::FLOAT32);
+    auto arch = input_tensor.device()->arch();
+    bool preserve_fp32_precision = (arch != tt::ARCH::GRAYSKULL) and (input_tensor.get_dtype() == DataType::FLOAT32);
     bool fp32_dest_acc_en = preserve_fp32_precision or
                             output_dtype == DataType::UINT32 or
                             output_dtype == DataType::INT32 or
                             output_dtype == DataType::FLOAT32 or
                             input_tensor.get_dtype() == DataType::UINT32 or
-                            input_tensor.get_dtype() == DataType::INT32; // MT: Currently only uint32/int32 is moved to
-                                                                         // DST directly, fp32 is converted to fp16b
+                            input_tensor.get_dtype() == DataType::INT32;

     auto output_memory_config = optional_output_tensor.has_value() ? optional_output_tensor.value().memory_config() : memory_config.value_or(input_tensor.memory_config());
     return prim::unary(queue_id, input_tensor, op_chain, output_dtype, output_memory_config, fp32_dest_acc_en, preserve_fp32_precision, optional_output_tensor);
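
In effect (a rough Python paraphrase of the hunk above, not code from this commit; the function name and string stand-ins for the dtype/arch enums are illustrative), the precision flags for a unary op are now chosen as:

```python
# Hypothetical paraphrase of the updated unary_impl flag selection.
def precision_flags(arch: str, input_dtype: str, output_dtype: str) -> tuple[bool, bool]:
    # fp32 inputs keep full precision on architectures that support it (non-Grayskull).
    preserve_fp32_precision = arch != "GRAYSKULL" and input_dtype == "FLOAT32"
    # fp32 destination accumulation is enabled for that case and for 32-bit dtypes.
    fp32_dest_acc_en = (
        preserve_fp32_precision
        or output_dtype in ("UINT32", "INT32", "FLOAT32")
        or input_dtype in ("UINT32", "INT32")
    )
    return preserve_fp32_precision, fp32_dest_acc_en
```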
