Skip to content

Commit

Permalink
Add op(heaviside) | feat(torchlib) (#1068)
Browse files Browse the repository at this point in the history
  • Loading branch information
justinchuby authored Sep 20, 2023
1 parent 3613db6 commit 64189f0
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 2 deletions.
9 changes: 7 additions & 2 deletions onnxscript/function_libs/torch_lib/ops/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -3590,10 +3590,15 @@ def aten_hardshrink_backward(
raise NotImplementedError()


def aten_heaviside(self: TensorType, values: TensorType) -> TensorType:
@torch_op("aten::heaviside")
def aten_heaviside(self: TReal, values: TReal) -> TReal:
"""heaviside(Tensor self, Tensor values) -> Tensor"""

raise NotImplementedError()
zero = op.CastLike(0, self)
one = op.CastLike(1, self)
intermediate = op.Where(op.Less(self, zero), zero, one)

return op.Where(op.Equal(self, zero), values, intermediate)


def aten_hinge_embedding_loss(
Expand Down
1 change: 1 addition & 0 deletions onnxscript/tests/function_libs/torch_lib/ops_test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -1495,6 +1495,7 @@ def _where_input_wrangler(
matcher=lambda sample: sample.args[1] == 2,
reason="fixme: 'bicubic' mode in ORT implemented differently with Torch",
),
TorchLibOpInfo("heaviside", core_ops.aten_heaviside),
TorchLibOpInfo(
"hstack",
core_ops.aten_hstack,
Expand Down

0 comments on commit 64189f0

Please sign in to comment.