Skip to content

Commit

Permalink
Add aten_hardtanh_backward function (#1715)
Browse files Browse the repository at this point in the history
Depends on #1707, will add unit test after #1707 merged.
  • Loading branch information
xiaowuhu authored Jul 4, 2024
1 parent 619f5ed commit 0670951
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion onnxscript/function_libs/torch_lib/ops/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -632,12 +632,15 @@ def aten_hardtanh(self: TReal, min_val: float = -1.0, max_val: float = 1.0) -> T
return op.Clip(self, min_val, max_val)


@torch_op("aten::hardtanh_backward", trace_only=True)
def aten_hardtanh_backward(
grad_output: TensorType, self: TensorType, min_val: float, max_val: float
) -> TensorType:
"""hardtanh_backward(Tensor grad_output, Tensor self, Scalar min_val, Scalar max_val) -> Tensor"""

raise NotImplementedError()
max_mask = op.Where(op.Greater(self, max_val), 0.0, 1.0)
min_mask = op.Where(op.Less(self, min_val), 0.0, 1.0)
return op.Mul(op.Mul(grad_output, max_mask), min_mask)


def aten_huber_loss(
Expand Down

0 comments on commit 0670951

Please sign in to comment.