From 3297052a503a858e2575eec7c50b3f12c45f5089 Mon Sep 17 00:00:00 2001
From: umadevimcw
Date: Tue, 26 Mar 2024 14:01:15 +0000
Subject: [PATCH] #6633: Refactor log1p and update test files

---
 .../unit_testing/backward_ops/test_backward_log1p.py | 12 ++++++++----
 tt_eager/tt_dnn/op_library/backward/backward_ops.cpp |  7 +++++--
 2 files changed, 13 insertions(+), 6 deletions(-)

diff --git a/tests/tt_eager/python_api_testing/unit_testing/backward_ops/test_backward_log1p.py b/tests/tt_eager/python_api_testing/unit_testing/backward_ops/test_backward_log1p.py
index ef2a1fc8888..117c2892207 100644
--- a/tests/tt_eager/python_api_testing/unit_testing/backward_ops/test_backward_log1p.py
+++ b/tests/tt_eager/python_api_testing/unit_testing/backward_ops/test_backward_log1p.py
@@ -5,7 +5,10 @@
 import torch
 import pytest
 import tt_lib
-from tests.tt_eager.python_api_testing.unit_testing.backward_ops.utility_funcs import compare_results, data_gen_pt_tt
+from tests.tt_eager.python_api_testing.unit_testing.backward_ops.utility_funcs import (
+    compare_pcc,
+    data_gen_with_range,
+)
 
 
 @pytest.mark.parametrize(
@@ -17,8 +20,9 @@
     ),
 )
 def test_bw_log1p(input_shapes, device):
-    in_data, input_tensor = data_gen_pt_tt(input_shapes, device, True)
-    grad_data, grad_tensor = data_gen_pt_tt(input_shapes, device)
+    in_data, input_tensor = data_gen_with_range(input_shapes, -100, 100, device, True)
+    grad_data, grad_tensor = data_gen_with_range(input_shapes, -10, 10, device)
+
     tt_output_tensor_on_device = tt_lib.tensor.log1p_bw(grad_tensor, input_tensor)
 
     in_data.retain_grad()
@@ -28,5 +32,5 @@
     pyt_y.backward(gradient=grad_data)
 
     golden_tensor = [in_data.grad]
-    comp_pass = compare_results(tt_output_tensor_on_device, golden_tensor)
+    comp_pass = compare_pcc(tt_output_tensor_on_device, golden_tensor)
     assert comp_pass
diff --git a/tt_eager/tt_dnn/op_library/backward/backward_ops.cpp b/tt_eager/tt_dnn/op_library/backward/backward_ops.cpp
index fd9b3a2ddd5..6a11bf6ba06 100644
--- a/tt_eager/tt_dnn/op_library/backward/backward_ops.cpp
+++ b/tt_eager/tt_dnn/op_library/backward/backward_ops.cpp
@@ -1253,8 +1253,11 @@ std::vector<Tensor> log10_bw(const Tensor& grad, const Tensor& input, const Memo
 // for -1 = inf
 std::vector<Tensor> _log1p_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) {
     std::vector<Tensor> grad_tensor;
-    Tensor grad_a = mul(grad, recip(add1(input, output_mem_config), output_mem_config), std::nullopt, output_mem_config);
-    grad_a = where(eq(input, full_like(input, -1.0, output_mem_config), std::nullopt, output_mem_config), std::numeric_limits<float>::infinity(), grad_a, output_mem_config);
+    Tensor t_inf = where(ltz(grad, output_mem_config), -std::numeric_limits<float>::infinity(), std::numeric_limits<float>::infinity(), output_mem_config);
+    Tensor t_inp1 = add1(input, output_mem_config);
+    Tensor grad_a = mul(grad, recip(t_inp1, output_mem_config), std::nullopt, output_mem_config);
+    grad_a = where(eq(input, full_like(input, -1.0, output_mem_config), std::nullopt, output_mem_config), t_inf, grad_a, output_mem_config);
+    grad_a = where(logical_and(eqz(t_inp1, output_mem_config), eqz(grad, output_mem_config)), std::nanf(" "), grad_a, output_mem_config);
     grad_tensor.emplace_back(grad_a);
     return grad_tensor;
 }
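
Reviewer note (not part of the patch): the new where-chain in _log1p_bw encodes three cases: the plain chain rule grad / (1 + x), a sign-aware infinity when input == -1 (ltz(grad) selects -inf, otherwise +inf), and NaN when both (1 + x) and grad are zero, since 0 * inf is indeterminate. Below is a minimal PyTorch sketch of those semantics for cross-checking against the autograd golden used by the test; log1p_bw_reference is a hypothetical name used only for illustration, not a function in the repo.

    import torch

    def log1p_bw_reference(grad: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        # Plain chain rule: d/dx log1p(x) = 1 / (1 + x).
        grad_in = grad / (1.0 + x)
        # At x == -1 the derivative diverges; the patch picks the sign of the
        # infinity from the sign of the incoming gradient (ltz(grad) -> -inf).
        t_inf = torch.where(
            grad < 0,
            torch.full_like(grad, float("-inf")),
            torch.full_like(grad, float("inf")),
        )
        grad_in = torch.where(x == -1.0, t_inf, grad_in)
        # 0 * inf is indeterminate: when (1 + x) == 0 and grad == 0 the result
        # is forced to NaN (the std::nanf(" ") branch) rather than +inf.
        grad_in = torch.where(
            ((1.0 + x) == 0) & (grad == 0),
            torch.full_like(grad, float("nan")),
            grad_in,
        )
        return grad_in

For example, log1p_bw_reference(torch.tensor([2.0, 0.0, 3.0]), torch.tensor([-1.0, -1.0, 0.0])) gives tensor([inf, nan, 3.]). The test's golden path still comes from torch autograd on torch.log1p, so this sketch is only a convenience for reasoning about the special cases; compare_pcc presumably scores agreement by Pearson correlation coefficient rather than the elementwise comparison compare_results performed.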