# Patch overview
# --------------
# This patch does two things:
#
# 1. Adds tests/ttnn/unit_tests/operations/eltwise/test_unary_fp32.py, a pytest
#    module that runs ttnn unary ops on float32 TILE_LAYOUT tensors. Every test
#    is decorated with skip_for_grayskull("Unsupported dtype for Grayskull").
#      * test_neg_fp32 / test_sin_fp32 / test_cos_fp32 / test_tan_fp32 /
#        test_relu_fp32 compare the ttnn result against the torch result with
#        torch.allclose(atol=1e-10, rtol=1e-5, equal_nan=False).
#      * run_unary_test(device, h, w, ttnn_function, pcc=0.9999) seeds torch with
#        0, builds an (h, w) torch.rand input, computes the reference through
#        ttnn.get_golden_function(ttnn_function), runs the op on device, moves
#        the result back (ROW_MAJOR_LAYOUT -> from_device -> to_torch) and checks
#        it with assert_with_pcc.
#      * test_exp, test_tanh, test_gelu, test_rsqrt, test_silu, test_log,
#        test_asin, test_acos, test_atan, test_sinh, test_asinh, test_cosh,
#        test_acosh and test_atanh call run_unary_test with h=64, w=128 and a
#        per-op PCC threshold (default 0.9999 where none is passed).
#
# 2. Changes unary_impl in ttnn/cpp/ttnn/operations/eltwise/unary/unary.cpp:
#    preserve_fp32_precision used to be true only when the first op in the chain
#    was TYPECAST and the input dtype was FLOAT32. It is now true whenever the
#    input tensor's device arch is not GRAYSKULL and the input dtype is FLOAT32.
#    The trailing "MT: ..." comment on the fp32_dest_acc_en expression is removed.
#
# Review notes (to be confirmed before merging)
# ---------------------------------------------
# NOTE(review): this copy of the patch appears to have lost its line breaks and
#   the angle-bracket template arguments in the C++ hunk ("std::optional&",
#   "static_cast(") -- regenerate it from the original commit before applying.
# NOTE(review): torch.rand draws from [0, 1). acosh is undefined below 1, so the
#   reference tensor in test_acosh is presumably all NaN -- confirm how
#   assert_with_pcc treats NaN; the test may be passing vacuously. An input
#   range of [1, ...) would exercise the op for real.
# NOTE(review): for the same reason test_log and test_rsqrt can see an exact 0
#   (-inf / inf in the reference) -- presumably rare with seed 0, but the input
#   range is not bounded away from the singularity.
# NOTE(review): test_relu_fp32 only feeds non-negative values, so relu acts as
#   the identity on this input and the clamping branch is never exercised.
# NOTE(review): the five allclose-based tests do not call torch.manual_seed,
#   unlike run_unary_test, so their inputs differ from run to run.
# NOTE(review): the new C++ line dereferences input_tensor.device() without a
#   visible null check -- confirm unary_impl is never reached with a host tensor.
# NOTE(review): preserve_fp32_precision now applies to every FLOAT32 input on
#   non-Grayskull devices, not only to TYPECAST chains -- confirm the wider
#   scope is intended for all unary ops routed through unary_impl.
diff --git a/tests/ttnn/unit_tests/operations/eltwise/test_unary_fp32.py b/tests/ttnn/unit_tests/operations/eltwise/test_unary_fp32.py new file mode 100644 index 00000000000..86b65b38028 --- /dev/null +++ b/tests/ttnn/unit_tests/operations/eltwise/test_unary_fp32.py @@ -0,0 +1,224 @@ +# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +import torch +import ttnn + +import pytest +from models.utility_functions import skip_for_grayskull +from tests.ttnn.utils_for_testing import assert_with_pcc + + +@skip_for_grayskull("Unsupported dtype for Grayskull") +@pytest.mark.parametrize( + "ttnn_function", + [ + ttnn.neg, + ], +) +def test_neg_fp32(device, ttnn_function): + x_torch = torch.tensor([[0.00001]], dtype=torch.float32) + y_torch = -x_torch + + x_tt = ttnn.from_torch(x_torch, dtype=ttnn.float32, layout=ttnn.TILE_LAYOUT, device=device) + + y_tt = ttnn_function(x_tt) + + tt_out = ttnn.to_torch(y_tt) + status = torch.allclose(y_torch, tt_out, atol=1e-10, rtol=1e-5, equal_nan=False) + assert status + + +@skip_for_grayskull("Unsupported dtype for Grayskull") +@pytest.mark.parametrize( + "ttnn_function", + [ + ttnn.sin, + ], +) +def test_sin_fp32(device, ttnn_function): + x_torch = torch.rand((64, 128), dtype=torch.float32) + y_torch = torch.sin(x_torch) + + x_tt = ttnn.from_torch(x_torch, dtype=ttnn.float32, layout=ttnn.TILE_LAYOUT, device=device) + + y_tt = ttnn_function(x_tt) + + tt_out = ttnn.to_torch(y_tt) + status = torch.allclose(y_torch, tt_out, atol=1e-10, rtol=1e-5, equal_nan=False) + assert status + + +@skip_for_grayskull("Unsupported dtype for Grayskull") +@pytest.mark.parametrize( + "ttnn_function", + [ + ttnn.cos, + ], +) +def test_cos_fp32(device, ttnn_function): + x_torch = torch.rand((64, 128), dtype=torch.float32) + y_torch = torch.cos(x_torch) + + x_tt = ttnn.from_torch(x_torch, dtype=ttnn.float32, layout=ttnn.TILE_LAYOUT, device=device) + + y_tt = ttnn_function(x_tt) + + tt_out = ttnn.to_torch(y_tt) + status = 
torch.allclose(y_torch, tt_out, atol=1e-10, rtol=1e-5, equal_nan=False) + assert status + + +@skip_for_grayskull("Unsupported dtype for Grayskull") +@pytest.mark.parametrize( + "ttnn_function", + [ + ttnn.tan, + ], +) +def test_tan_fp32(device, ttnn_function): + x_torch = torch.rand((64, 128), dtype=torch.float32) + y_torch = torch.tan(x_torch) + + x_tt = ttnn.from_torch(x_torch, dtype=ttnn.float32, layout=ttnn.TILE_LAYOUT, device=device) + + y_tt = ttnn_function(x_tt) + + tt_out = ttnn.to_torch(y_tt) + status = torch.allclose(y_torch, tt_out, atol=1e-10, rtol=1e-5, equal_nan=False) + assert status + + +@skip_for_grayskull("Unsupported dtype for Grayskull") +@pytest.mark.parametrize( + "ttnn_function", + [ + ttnn.relu, + ], +) +def test_relu_fp32(device, ttnn_function): + x_torch = torch.rand((64, 128), dtype=torch.float32) + y_torch = torch.relu(x_torch) + + x_tt = ttnn.from_torch(x_torch, dtype=ttnn.float32, layout=ttnn.TILE_LAYOUT, device=device) + + y_tt = ttnn_function(x_tt) + + tt_out = ttnn.to_torch(y_tt) + status = torch.allclose(y_torch, tt_out, atol=1e-10, rtol=1e-5, equal_nan=False) + assert status + + +def run_unary_test(device, h, w, ttnn_function, pcc=0.9999): + torch.manual_seed(0) + + torch_input_tensor = torch.rand((h, w), dtype=torch.float32) + golden_function = ttnn.get_golden_function(ttnn_function) + torch_output_tensor = golden_function(torch_input_tensor, device=device) + + input_tensor = ttnn.from_torch(torch_input_tensor, dtype=ttnn.float32, layout=ttnn.TILE_LAYOUT, device=device) + output_tensor = ttnn_function(input_tensor) + output_tensor = ttnn.to_layout(output_tensor, ttnn.ROW_MAJOR_LAYOUT) + output_tensor = ttnn.from_device(output_tensor) + output_tensor = ttnn.to_torch(output_tensor) + + assert_with_pcc(torch_output_tensor, output_tensor, pcc) + + +@skip_for_grayskull("Unsupported dtype for Grayskull") +@pytest.mark.parametrize("h", [64]) +@pytest.mark.parametrize("w", [128]) +def test_exp(device, h, w): + run_unary_test(device, h, 
w, ttnn.exp, pcc=0.9998) + + +@skip_for_grayskull("Unsupported dtype for Grayskull") +@pytest.mark.parametrize("h", [64]) +@pytest.mark.parametrize("w", [128]) +def test_tanh(device, h, w): + run_unary_test(device, h, w, ttnn.tanh, pcc=0.993) + + +@skip_for_grayskull("Unsupported dtype for Grayskull") +@pytest.mark.parametrize("h", [64]) +@pytest.mark.parametrize("w", [128]) +def test_gelu(device, h, w): + run_unary_test(device, h, w, ttnn.gelu, pcc=0.9996) + + +@skip_for_grayskull("Unsupported dtype for Grayskull") +@pytest.mark.parametrize("h", [64]) +@pytest.mark.parametrize("w", [128]) +def test_rsqrt(device, h, w): + run_unary_test(device, h, w, ttnn.rsqrt) + + +@skip_for_grayskull("Unsupported dtype for Grayskull") +@pytest.mark.parametrize("h", [64]) +@pytest.mark.parametrize("w", [128]) +def test_silu(device, h, w): + run_unary_test(device, h, w, ttnn.silu) + + +@skip_for_grayskull("Unsupported dtype for Grayskull") +@pytest.mark.parametrize("h", [64]) +@pytest.mark.parametrize("w", [128]) +def test_log(device, h, w): + run_unary_test(device, h, w, ttnn.log) + + +@skip_for_grayskull("Unsupported dtype for Grayskull") +@pytest.mark.parametrize("h", [64]) +@pytest.mark.parametrize("w", [128]) +def test_asin(device, h, w): + run_unary_test(device, h, w, ttnn.asin, pcc=0.998) + + +@skip_for_grayskull("Unsupported dtype for Grayskull") +@pytest.mark.parametrize("h", [64]) +@pytest.mark.parametrize("w", [128]) +def test_acos(device, h, w): + run_unary_test(device, h, w, ttnn.acos, pcc=0.998) + + +@skip_for_grayskull("Unsupported dtype for Grayskull") +@pytest.mark.parametrize("h", [64]) +@pytest.mark.parametrize("w", [128]) +def test_atan(device, h, w): + run_unary_test(device, h, w, ttnn.atan) + + +@skip_for_grayskull("Unsupported dtype for Grayskull") +@pytest.mark.parametrize("h", [64]) +@pytest.mark.parametrize("w", [128]) +def test_sinh(device, h, w): + run_unary_test(device, h, w, ttnn.sinh) + + +@skip_for_grayskull("Unsupported dtype for Grayskull") 
+@pytest.mark.parametrize("h", [64]) +@pytest.mark.parametrize("w", [128]) +def test_asinh(device, h, w): + run_unary_test(device, h, w, ttnn.asinh, pcc=0.9997) + + +@skip_for_grayskull("Unsupported dtype for Grayskull") +@pytest.mark.parametrize("h", [64]) +@pytest.mark.parametrize("w", [128]) +def test_cosh(device, h, w): + run_unary_test(device, h, w, ttnn.cosh, pcc=0.999) + + +@skip_for_grayskull("Unsupported dtype for Grayskull") +@pytest.mark.parametrize("h", [64]) +@pytest.mark.parametrize("w", [128]) +def test_acosh(device, h, w): + run_unary_test(device, h, w, ttnn.acosh) + + +@skip_for_grayskull("Unsupported dtype for Grayskull") +@pytest.mark.parametrize("h", [64]) +@pytest.mark.parametrize("w", [128]) +def test_atanh(device, h, w): + run_unary_test(device, h, w, ttnn.atanh, pcc=0.997) diff --git a/ttnn/cpp/ttnn/operations/eltwise/unary/unary.cpp b/ttnn/cpp/ttnn/operations/eltwise/unary/unary.cpp index f661b1cfedd..7a40003fa52 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/unary/unary.cpp +++ b/ttnn/cpp/ttnn/operations/eltwise/unary/unary.cpp @@ -23,14 +23,14 @@ inline Tensor unary_impl( const std::optional& memory_config = std::nullopt, const std::optional& optional_output_tensor = std::nullopt) { DataType output_dtype = (op_chain[0].op_type == UnaryOpType::TYPECAST) ? 
static_cast(op_chain[0].params[1]) : input_tensor.get_dtype(); - bool preserve_fp32_precision = (op_chain[0].op_type == UnaryOpType::TYPECAST) and (input_tensor.get_dtype() == DataType::FLOAT32); + auto arch = input_tensor.device()->arch(); + bool preserve_fp32_precision = (arch != tt::ARCH::GRAYSKULL) and (input_tensor.get_dtype() == DataType::FLOAT32); bool fp32_dest_acc_en = preserve_fp32_precision or output_dtype == DataType::UINT32 or output_dtype == DataType::INT32 or output_dtype == DataType::FLOAT32 or input_tensor.get_dtype() == DataType::UINT32 or - input_tensor.get_dtype() == DataType::INT32; // MT: Currently only uint32/int32 is moved to - // DST directly, fp32 is converted to fp16b + input_tensor.get_dtype() == DataType::INT32; auto output_memory_config = optional_output_tensor.has_value() ? optional_output_tensor.value().memory_config() : memory_config.value_or(input_tensor.memory_config()); return prim::unary(queue_id, input_tensor, op_chain, output_dtype, output_memory_config, fp32_dest_acc_en, preserve_fp32_precision, optional_output_tensor);