#6633: Update log2 backward ops
umadevimcw committed Mar 27, 2024
1 parent d56461a · commit 48cc927
Showing 2 changed files with 11 additions and 5 deletions.
@@ -5,7 +5,10 @@
 import torch
 import pytest
 import tt_lib
-from tests.tt_eager.python_api_testing.unit_testing.backward_ops.utility_funcs import compare_results, data_gen_pt_tt
+from tests.tt_eager.python_api_testing.unit_testing.backward_ops.utility_funcs import (
+    compare_pcc,
+    data_gen_with_range,
+)


 @pytest.mark.parametrize(
@@ -17,16 +20,16 @@
     ),
 )
 def test_bw_log2(input_shapes, device):
-    in_data, input_tensor = data_gen_pt_tt(input_shapes, device, True)
-    grad_data, grad_tensor = data_gen_pt_tt(input_shapes, device)
+    in_data, input_tensor = data_gen_with_range(input_shapes, -100, 100, device, True)
+    grad_data, grad_tensor = data_gen_with_range(input_shapes, -10, 10, device)

     tt_output_tensor_on_device = tt_lib.tensor.log2_bw(grad_tensor, input_tensor)

     in_data.retain_grad()

     pyt_y = torch.log2(in_data)

     pyt_y.backward(gradient=grad_data)

     golden_tensor = [in_data.grad]
-    comp_pass = compare_results(tt_output_tensor_on_device, golden_tensor)
+    comp_pass = compare_pcc(tt_output_tensor_on_device, golden_tensor)
     assert comp_pass
3 changes: 3 additions & 0 deletions tt_eager/tt_dnn/op_library/backward/backward_ops.cpp
@@ -1652,7 +1652,10 @@ std::vector<Tensor> ceil_bw(const Tensor& grad, const Tensor& input, const Memor
 // bw(log2(in)) = grad/(in * 0.69314718055994530942)
 std::vector<Tensor> _log2_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) {
     std::vector<Tensor> grad_tensor;
+    Tensor t_inf = where(ltz(grad, output_mem_config), -std::numeric_limits<float>::infinity(), std::numeric_limits<float>::infinity(), output_mem_config);
     Tensor grad_a = mul(grad, recip(mul_unary(input, M_LN2, output_mem_config), output_mem_config), std::nullopt, output_mem_config);
+    grad_a = where(logical_and(eqz(input, output_mem_config), eqz(grad, output_mem_config), std::nullopt, output_mem_config), std::nanf(" "),
+                   where(eqz(input, output_mem_config), t_inf, grad_a, output_mem_config), output_mem_config);
     grad_tensor.emplace_back(grad_a);
     return grad_tensor;
 }
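The three added lines pin down the gradient at the singularity input == 0, where grad/(in * ln 2) is not finite: a zero incoming gradient there is 0/0 and maps to NaN, while a nonzero one yields an infinity whose sign follows the gradient. Below is a minimal sketch of the same edge-case logic in plain PyTorch, for reference only (log2_bw_reference is an illustrative name, not a tt_lib API):

import math

import torch


def log2_bw_reference(grad, x):
    # bw(log2(x)) = grad / (x * ln 2), per the comment on _log2_bw above.
    out = grad / (x * math.log(2.0))
    # x == 0 with nonzero grad: signed infinity, sign taken from grad
    # (mirrors t_inf = where(ltz(grad), -inf, +inf) in the C++ change).
    t_inf = torch.where(grad < 0, torch.full_like(grad, float("-inf")), torch.full_like(grad, float("inf")))
    out = torch.where((x == 0) & (grad != 0), t_inf, out)
    # x == 0 and grad == 0: 0/0 is indeterminate, so return NaN.
    out = torch.where((x == 0) & (grad == 0), torch.full_like(out, float("nan")), out)
    return out


# Sanity check against autograd away from the singularity:
x = torch.tensor([0.5, 1.0, 4.0], requires_grad=True)
g = torch.tensor([1.0, -2.0, 0.5])
torch.log2(x).backward(gradient=g)
assert torch.allclose(x.grad, log2_bw_reference(g, x.detach()))

Applying the NaN case last here is equivalent to the C++ nesting, which tests eqz(input) and eqz(grad) in the outer where; both orders leave non-singular elements at grad / (x * ln 2).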
