#6633: Refactor log1p and update test files
umadevimcw committed Mar 27, 2024
1 parent 48cc927 commit 3297052
Showing 2 changed files with 13 additions and 6 deletions.
@@ -5,7 +5,10 @@
import torch
import pytest
import tt_lib
-from tests.tt_eager.python_api_testing.unit_testing.backward_ops.utility_funcs import compare_results, data_gen_pt_tt
+from tests.tt_eager.python_api_testing.unit_testing.backward_ops.utility_funcs import (
+    compare_pcc,
+    data_gen_with_range,
+)


@pytest.mark.parametrize(
@@ -17,8 +20,9 @@
),
)
def test_bw_log1p(input_shapes, device):
-in_data, input_tensor = data_gen_pt_tt(input_shapes, device, True)
-grad_data, grad_tensor = data_gen_pt_tt(input_shapes, device)
+in_data, input_tensor = data_gen_with_range(input_shapes, -100, 100, device, True)
+grad_data, grad_tensor = data_gen_with_range(input_shapes, -10, 10, device)

tt_output_tensor_on_device = tt_lib.tensor.log1p_bw(grad_tensor, input_tensor)

in_data.retain_grad()
@@ -28,5 +32,5 @@ def test_bw_log1p(input_shapes, device):
pyt_y.backward(gradient=grad_data)

golden_tensor = [in_data.grad]
-comp_pass = compare_results(tt_output_tensor_on_device, golden_tensor)
+comp_pass = compare_pcc(tt_output_tensor_on_device, golden_tensor)
assert comp_pass
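
For context, the golden values that compare_pcc checks against come from torch.log1p, whose backward pass simply scales the incoming gradient by 1 / (1 + x). Below is a minimal standalone sketch of that reference computation in plain PyTorch; the 1x1x32x32 shape is a placeholder for the parametrized input_shapes, and the tt_lib/device side is omitted.

import torch

# Reference gradient of log1p: d/dx log(1 + x) = 1 / (1 + x),
# so the upstream gradient is divided by (1 + x).
in_data = torch.empty(1, 1, 32, 32).uniform_(-100, 100).requires_grad_()
grad_data = torch.empty(1, 1, 32, 32).uniform_(-10, 10)

pyt_y = torch.log1p(in_data)
pyt_y.backward(gradient=grad_data)

# The same result from the closed-form derivative.
expected = grad_data / (1.0 + in_data.detach())
assert torch.allclose(in_data.grad, expected, equal_nan=True)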
7 changes: 5 additions & 2 deletions tt_eager/tt_dnn/op_library/backward/backward_ops.cpp
@@ -1253,8 +1253,11 @@ std::vector<Tensor> log10_bw(const Tensor& grad, const Tensor& input, const Memo
// for -1 = inf
std::vector<Tensor> _log1p_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) {
std::vector<Tensor> grad_tensor;
-Tensor grad_a = mul(grad, recip(add1(input, output_mem_config), output_mem_config), std::nullopt, output_mem_config);
-grad_a = where(eq(input, full_like(input, -1.0, output_mem_config), std::nullopt, output_mem_config), std::numeric_limits<float>::infinity(), grad_a, output_mem_config);
+Tensor t_inf = where(ltz(grad, output_mem_config), -std::numeric_limits<float>::infinity(), std::numeric_limits<float>::infinity(), output_mem_config);
+Tensor t_inp1 = add1(input, output_mem_config);
+Tensor grad_a = mul(grad, recip(t_inp1, output_mem_config), std::nullopt, output_mem_config);
+grad_a = where(eq(input, full_like(input, -1.0, output_mem_config), std::nullopt, output_mem_config), t_inf, grad_a, output_mem_config);
+grad_a = where(logical_and(eqz(t_inp1, output_mem_config), eqz(grad, output_mem_config)), std::nanf(" "), grad_a, output_mem_config);
grad_tensor.emplace_back(grad_a);
return grad_tensor;
}
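
The new branches mirror what plain IEEE-754 division of grad by (input + 1) would produce: a positive gradient at input == -1 gives +inf, a negative one gives -inf (the t_inf selection), and a zero gradient there gives NaN (the 0/0 case handled by the final where). A small PyTorch sketch of that expected edge-case behaviour, not the tt_lib kernel itself:

import torch

grad = torch.tensor([2.0, -3.0, 0.0, 1.0])
x = torch.tensor([-1.0, -1.0, -1.0, 4.0])

# grad / (x + 1) is the quantity _log1p_bw builds with mul/recip,
# with the where() branches covering the x == -1 cases.
out = grad / (x + 1.0)
print(out)  # expected: inf, -inf, nan, 0.2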
