#14020: Update softshrink_bw with golden function (#13982)
* #14020: Update softshrink_bw with golden function

* #14020: Update golden function to pass args directly

* #14020: Update golden function

* #14020: Update golden function
VirdhatchaniKN authored Oct 25, 2024
1 parent df85fd3 commit 2ad5a1a
Showing 2 changed files with 17 additions and 44 deletions.
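
The tests now validate against the golden function registered for the op instead of re-deriving the torch reference inline. As a minimal sketch of the call path the updated tests use (the tensor shapes below are illustrative, not taken from the commit), the test diff that follows reduces to fetching and calling this function:

import torch
import ttnn

# Illustrative inputs; the real tests build these with data_gen_with_range().
in_data = torch.randn(1, 1, 32, 32, requires_grad=True)
grad_data = torch.randn(1, 1, 32, 32)

# The golden function attached to ttnn.softshrink_bw runs the torch reference
# (torch.nn.functional.softshrink followed by autograd) and returns [in_data.grad].
golden_function = ttnn.get_golden_function(ttnn.softshrink_bw)
golden_tensor = golden_function(grad_data, in_data, 0.5)  # third argument is lambd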
@@ -23,15 +23,11 @@
 def test_bw_softshrink(input_shapes, lambd, device):
     in_data, input_tensor = data_gen_with_range(input_shapes, -100, 100, device, True)
     grad_data, grad_tensor = data_gen_with_range(input_shapes, -20, 20, device)
-    in_data.retain_grad()
-
-    pyt_y = torch.nn.functional.softshrink(in_data, lambd=lambd)

     tt_output_tensor_on_device = ttnn.softshrink_bw(grad_tensor, input_tensor, lambd=lambd)

-    pyt_y.backward(gradient=grad_data)
-
-    golden_tensor = [in_data.grad]
+    golden_function = ttnn.get_golden_function(ttnn.softshrink_bw)
+    golden_tensor = golden_function(grad_data, in_data, lambd)

     comp_pass = compare_results(tt_output_tensor_on_device, golden_tensor)
     assert comp_pass
@@ -48,15 +44,11 @@ def test_bw_softshrink(input_shapes, lambd, device):
 def test_bw_softshrink_default(input_shapes, device):
     in_data, input_tensor = data_gen_with_range(input_shapes, -100, 100, device, True)
     grad_data, grad_tensor = data_gen_with_range(input_shapes, -20, 20, device)
-    in_data.retain_grad()
-
-    pyt_y = torch.nn.functional.softshrink(in_data)

     tt_output_tensor_on_device = ttnn.softshrink_bw(grad_tensor, input_tensor)

-    pyt_y.backward(gradient=grad_data)
-
-    golden_tensor = [in_data.grad]
+    golden_function = ttnn.get_golden_function(ttnn.softshrink_bw)
+    golden_tensor = golden_function(grad_data, in_data)

     comp_pass = compare_results(tt_output_tensor_on_device, golden_tensor)
     assert comp_pass
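
For reference, softshrink's derivative is 1 wherever |x| > lambd and 0 inside [-lambd, lambd], so the golden tensor is just the incoming gradient with that dead zone masked out. A quick standalone torch check of the rule (names and shapes here are illustrative, not part of the commit):

import torch

x = torch.randn(4, 4, requires_grad=True)
grad = torch.randn(4, 4)
lambd = 0.5

torch.nn.functional.softshrink(x, lambd).backward(gradient=grad)
expected = grad * (x.abs() > lambd)  # gradient is zeroed inside [-lambd, lambd]
assert torch.allclose(x.grad, expected)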
ttnn/ttnn/operations/unary_backward.py: 45 changes (13 additions, 32 deletions)
@@ -33,27 +33,8 @@ def _golden_function_div_no_nan(torch_op, grad_tensor, input_tensor, alpha, *args
     return golden_tensor


-def _golden_function_unary_backward_with_float(torch_op, grad_tensor, input_tensor, alpha=None, *args, **kwargs):
-    if torch_op == "leaky_relu":
-        if alpha != None:
-            pyt_y = torch.nn.functional.leaky_relu(input_tensor, negative_slope=alpha)
-        else:
-            pyt_y = torch.nn.functional.leaky_relu(input_tensor)
-    elif torch_op == "elu":
-        if alpha != None:
-            pyt_y = torch.nn.functional.elu(input_tensor, alpha=alpha)
-        else:
-            pyt_y = torch.nn.functional.elu(input_tensor)
-    elif torch_op == "celu":
-        if alpha != None:
-            pyt_y = torch.nn.functional.celu(input_tensor, alpha)
-        else:
-            pyt_y = torch.nn.functional.celu(input_tensor)
-    else:
-        if alpha != None:
-            pyt_y = torch_op(input_tensor, alpha)
-        else:
-            pyt_y = torch_op(input_tensor)
+def _golden_function_unary_backward_with_float(torch_op, grad_tensor, input_tensor, *args, **kwargs):
+    pyt_y = torch_op(input_tensor, *args, **kwargs)
     input_tensor.retain_grad()
     pyt_y.backward(gradient=grad_tensor)
     golden_tensor = [input_tensor.grad]
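
With the string-based dispatch removed, the helper simply forwards *args/**kwargs to the torch callable, which is why the registrations below pass the real torch function plus the keyword torch expects (lambd=, negative_slope=, alpha=, eps=). A rough standalone illustration of that calling convention, using a locally defined copy of the helper and made-up tensors rather than the module's own code:

import torch

def golden_with_float(torch_op, grad_tensor, input_tensor, *args, **kwargs):
    # Mirrors the simplified helper: run the torch forward, then backprop the incoming gradient.
    pyt_y = torch_op(input_tensor, *args, **kwargs)
    input_tensor.retain_grad()
    pyt_y.backward(gradient=grad_tensor)
    return [input_tensor.grad]

x = torch.randn(2, 3, requires_grad=True)
g = torch.randn(2, 3)
golden = golden_with_float(torch.nn.functional.softshrink, g, x, lambd=0.5)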
@@ -165,36 +146,36 @@ def _golden_function_backward_with_reverse_string(

 ttnn.attach_golden_function(
     ttnn.hardshrink_bw,
-    golden_function=lambda grad, input, alpha=None, *args, **kwargs: _golden_function_unary_backward_with_float(
-        torch.hardshrink, grad, input, alpha, *args, **kwargs
+    golden_function=lambda grad, input, alpha=0.5, *args, **kwargs: _golden_function_unary_backward_with_float(
+        torch.hardshrink, grad, input, lambd=alpha, *args, **kwargs
     ),
 )

 ttnn.attach_golden_function(
     ttnn.softshrink_bw,
-    golden_function=lambda grad, input, alpha=None, *args, **kwargs: _golden_function_unary_backward_with_float(
-        torch.softshrink, grad, input, alpha, *args, **kwargs
+    golden_function=lambda grad, input, alpha=0.5, *args, **kwargs: _golden_function_unary_backward_with_float(
+        torch.nn.functional.softshrink, grad, input, lambd=alpha, *args, **kwargs
     ),
 )

 ttnn.attach_golden_function(
     ttnn.leaky_relu_bw,
-    golden_function=lambda grad, input, alpha=None, *args, **kwargs: _golden_function_unary_backward_with_float(
-        "leaky_relu", grad, input, alpha, *args, **kwargs
+    golden_function=lambda grad, input, alpha=1e-2, *args, **kwargs: _golden_function_unary_backward_with_float(
+        torch.nn.functional.leaky_relu, grad, input, negative_slope=alpha, *args, **kwargs
     ),
 )

 ttnn.attach_golden_function(
     ttnn.elu_bw,
-    golden_function=lambda grad, input, alpha=None, *args, **kwargs: _golden_function_unary_backward_with_float(
-        "elu", grad, input, alpha, *args, **kwargs
+    golden_function=lambda grad, input, alpha=1.0, *args, **kwargs: _golden_function_unary_backward_with_float(
+        torch.nn.functional.elu, grad, input, alpha=alpha, *args, **kwargs
     ),
 )

 ttnn.attach_golden_function(
     ttnn.celu_bw,
-    golden_function=lambda grad, input, alpha=None, *args, **kwargs: _golden_function_unary_backward_with_float(
-        "celu", grad, input, alpha, *args, **kwargs
+    golden_function=lambda grad, input, alpha=1.0, *args, **kwargs: _golden_function_unary_backward_with_float(
+        torch.nn.functional.celu, grad, input, alpha=alpha, *args, **kwargs
     ),
 )

@@ -208,7 +189,7 @@ def _golden_function_backward_with_reverse_string(
 ttnn.attach_golden_function(
     ttnn.logiteps_bw,
     golden_function=lambda grad, input, alpha=None, *args, **kwargs: _golden_function_unary_backward_with_float(
-        torch.logit, grad, input, alpha, *args, **kwargs
+        torch.logit, grad, input, eps=alpha, *args, **kwargs
     ),
 )
