diff --git a/tests/tt_eager/python_api_testing/unit_testing/backward_ops/test_backward_log2.py b/tests/tt_eager/python_api_testing/unit_testing/backward_ops/test_backward_log2.py
index cd11a1b67dd..d44e4b3e5c0 100644
--- a/tests/tt_eager/python_api_testing/unit_testing/backward_ops/test_backward_log2.py
+++ b/tests/tt_eager/python_api_testing/unit_testing/backward_ops/test_backward_log2.py
@@ -5,7 +5,10 @@
 import torch
 import pytest
 import tt_lib
-from tests.tt_eager.python_api_testing.unit_testing.backward_ops.utility_funcs import compare_results, data_gen_pt_tt
+from tests.tt_eager.python_api_testing.unit_testing.backward_ops.utility_funcs import (
+    compare_pcc,
+    data_gen_with_range,
+)
 
 
 @pytest.mark.parametrize(
@@ -17,8 +20,9 @@
     ),
 )
 def test_bw_log2(input_shapes, device):
-    in_data, input_tensor = data_gen_pt_tt(input_shapes, device, True)
-    grad_data, grad_tensor = data_gen_pt_tt(input_shapes, device)
+    in_data, input_tensor = data_gen_with_range(input_shapes, -100, 100, device, True)
+    grad_data, grad_tensor = data_gen_with_range(input_shapes, -10, 10, device)
+
     tt_output_tensor_on_device = tt_lib.tensor.log2_bw(grad_tensor, input_tensor)
 
     in_data.retain_grad()
@@ -26,7 +30,6 @@ def test_bw_log2(input_shapes, device):
     pyt_y = torch.log2(in_data)
     pyt_y.backward(gradient=grad_data)
 
-
     golden_tensor = [in_data.grad]
-    comp_pass = compare_results(tt_output_tensor_on_device, golden_tensor)
+    comp_pass = compare_pcc(tt_output_tensor_on_device, golden_tensor)
     assert comp_pass
diff --git a/tt_eager/tt_dnn/op_library/backward/backward_ops.cpp b/tt_eager/tt_dnn/op_library/backward/backward_ops.cpp
index 0317852a97e..fd9b3a2ddd5 100644
--- a/tt_eager/tt_dnn/op_library/backward/backward_ops.cpp
+++ b/tt_eager/tt_dnn/op_library/backward/backward_ops.cpp
@@ -1652,7 +1652,10 @@ std::vector<Tensor> ceil_bw(const Tensor& grad, const Tensor& input, const Memor
 // bw(log2(in)) = grad/(in * 0.69314718055994530942)
 std::vector<Tensor> _log2_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) {
     std::vector<Tensor> grad_tensor;
+    Tensor t_inf = where(ltz(grad, output_mem_config), -std::numeric_limits<float>::infinity(), std::numeric_limits<float>::infinity(), output_mem_config);
     Tensor grad_a = mul(grad, recip(mul_unary(input, M_LN2, output_mem_config), output_mem_config), std::nullopt, output_mem_config);
+    grad_a = where(logical_and(eqz(input, output_mem_config), eqz(grad, output_mem_config), std::nullopt, output_mem_config), std::nanf(" "),
+                   where(eqz(input, output_mem_config), t_inf, grad_a, output_mem_config), output_mem_config);
     grad_tensor.emplace_back(grad_a);
     return grad_tensor;
 }
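
Note: as a minimal cross-check of the edge-case handling added to _log2_bw above, the PyTorch sketch below (not part of this diff; the helper name log2_bw_reference is hypothetical) mirrors the same logic on the host: d/dx log2(x) = 1/(x * ln 2), so the incoming gradient is scaled by that factor; input == 0 maps to +/-inf with the sign taken from grad; and input == 0 with grad == 0 maps to NaN.

import math

import torch


def log2_bw_reference(grad: torch.Tensor, inp: torch.Tensor) -> torch.Tensor:
    # Analytic gradient: grad / (input * ln(2)), matching the mul/recip/mul_unary chain.
    grad_a = grad / (inp * math.log(2))
    # At input == 0 the gradient diverges; its sign follows grad,
    # mirroring where(ltz(grad), -inf, +inf) in the C++ change.
    t_inf = torch.where(
        grad < 0,
        torch.full_like(grad, float("-inf")),
        torch.full_like(grad, float("inf")),
    )
    # input == 0 together with grad == 0 is indeterminate -> NaN, mirroring std::nanf(" ").
    t_nan = torch.full_like(grad, float("nan"))
    return torch.where((inp == 0) & (grad == 0), t_nan, torch.where(inp == 0, t_inf, grad_a))


# Sanity check against autograd on well-behaved inputs:
x = torch.tensor([0.5, 2.0, 8.0], requires_grad=True)
g = torch.tensor([1.0, -1.0, 2.0])
torch.log2(x).backward(gradient=g)
assert torch.allclose(x.grad, log2_bw_reference(g, x.detach()))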