Skip to content

Commit

Permalink
#5082: Pow gradient calculation differs from PyTorch
Browse files Browse the repository at this point in the history
- update to be the same as PT
- add zero-exponent test
  • Loading branch information
Muthu authored and muthutt committed Feb 2, 2024
1 parent 427f16a commit 2aee577
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,50 @@
from tests.tt_eager.python_api_testing.unit_testing.backward_ops.utility_funcs import data_gen_pt_tt, compare_results


@pytest.mark.parametrize(
    "input_shapes",
    ((torch.Size([1, 1, 32, 32])),),
)
@pytest.mark.parametrize(
    "exponent",
    [
        -0.01,
        -1.0,
    ],
)
def test_negative_exponent(input_shapes, exponent, device):
    """Negative exponents are unsupported: unary_pow_bw must raise RuntimeError.

    The backward op guards with TT_FATAL(exponent >= 0.0), so any negative
    exponent should surface as a RuntimeError carrying that guard's message.
    """
    # Input tensor with gradient tracking enabled; upstream gradient tensor.
    in_data, input_tensor = data_gen_pt_tt(input_shapes, device, True)
    grad_data, grad_tensor = data_gen_pt_tt(input_shapes, device)

    with pytest.raises(RuntimeError) as _e:
        # Result is discarded on purpose — the call itself must raise.
        tt_lib.tensor.unary_pow_bw(grad_tensor, input_tensor, exponent=exponent)
    # Match against the exception value itself; str(ExceptionInfo) includes
    # file/line location info, not just the exception message.
    assert "exponent >= 0.0" in str(_e.value)


@pytest.mark.parametrize(
    "input_shapes",
    ((torch.Size([1, 1, 32, 32])),),
)
@pytest.mark.parametrize(
    "exponent",
    [
        0,
    ],
)
def test_fw_exponent(input_shapes, exponent, device):
    """Forward pow with a zero exponent must match torch.pow elementwise.

    For exponent 0 this is an all-ones tensor; compares the device result
    against the PyTorch golden via compare_results.
    """
    grad_data, grad_tensor = data_gen_pt_tt(input_shapes, device)

    # Use the parametrized exponent instead of a hard-coded 0.0 so additional
    # exponent values can be added to the parametrize list above.
    golden_tensor = [
        torch.pow(grad_data, float(exponent)),
    ]
    tt_output_tensor_on_device = tt_lib.tensor.pow(grad_tensor, float(exponent))
    status = compare_results([tt_output_tensor_on_device], golden_tensor)
    assert status


@pytest.mark.parametrize(
"input_shapes",
(
Expand All @@ -22,6 +66,7 @@
0.0,
1.0,
2.0,
5.0,
],
)
def test_bw_unary_pow(input_shapes, exponent, device):
Expand Down
7 changes: 7 additions & 0 deletions tt_eager/tt_dnn/op_library/backward/backward_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,13 @@ std::vector<Tensor> unary_mul_bw(const Tensor& grad, const Tensor& input, float

std::vector<Tensor> _unary_pow_bw(const Tensor& grad, const Tensor& input, float exponent, const MemoryConfig& output_mem_config) {
std::vector<Tensor> grad_tensor;
const float ZERO_THRESHOLD = std::numeric_limits<float>::epsilon()*10.0f;
TT_FATAL(exponent >= 0.0, "negative exponents are not supported; use recip(pow(input,abs(exponent)))");
if ( std::abs(exponent) < ZERO_THRESHOLD ) {
grad_tensor.emplace_back( zeros_like( input, output_mem_config) );
return grad_tensor;
}

Tensor power_input = power(input, exponent - 1, output_mem_config);

Tensor result = mul_unary(power_input, exponent, output_mem_config);
Expand Down

0 comments on commit 2aee577

Please sign in to comment.